# Unified SR Model Training on Kaggle

This notebook trains the Unified Super-Resolution (SR) Model using the **STARE** (Medical), **Tuberculosis** (Medical), and **SpaceNet 2** (Satellite) datasets.

## Workflow
1.  **Setup**: Copy code and install dependencies.
2.  **Data Preparation**: 
    *   We will read the raw images from Kaggle Input.
    *   We will generate Low-Resolution (LR) images and create Train/Validation splits for each dataset.
    *   All processed data will be stored in `/kaggle/working/data`.
3.  **Training**: Train the model using the processed data.

In [None]:
import os
import shutil
import cv2
import numpy as np
import random
from glob import glob
import torch

# Check GPU
# Only reports whether PyTorch can see a CUDA device; the notebook continues either way.
print(f"GPU Available: {torch.cuda.is_available()}")

## 1. Setup Paths
**IMPORTANT**: Verify these paths match your Kaggle Input structure.

In [None]:
# INPUT PATHS (Adjust if needed)
# Each raw path is searched recursively for images by the preprocessing cell,
# so it only needs to point at the dataset's root folder.

# STARE Dataset (medical)
MEDICAL_RAW_PATH = '/kaggle/input/stare-dataset' 

# Tuberculosis Dataset (medical; processed into the same output folders as STARE)
TB_RAW_PATH = '/kaggle/input/tuberculosis-chest-xrays-shenzhen'

# SpaceNet 2 Dataset (satellite)
SATELLITE_RAW_PATH = '/kaggle/input/spacenet-2-paris-buildings' 

# OUTPUT PATH (Working Directory)
# Root folder that receives every processed HR/LR image.
BASE_DATA_DIR = '/kaggle/working/data'
# Kaggle dataset holding the project code; it must contain src/ and train.py.
CODE_DIR = '/kaggle/input/unified-sr-code' # Where you uploaded src/ and train.py

In [None]:
# Copy the uploaded project code (src/ and train.py) into /kaggle/working.
if not os.path.exists(CODE_DIR):
    print("WARNING: Code directory not found. Make sure you added the 'unified-sr-code' dataset.")
else:
    # Remove any copy left by an earlier run first: copytree refuses to write
    # into a destination directory that already exists.
    if os.path.exists('/kaggle/working/src'):
        shutil.rmtree('/kaggle/working/src')
    shutil.copytree(os.path.join(CODE_DIR, 'src'), '/kaggle/working/src')
    shutil.copy(os.path.join(CODE_DIR, 'train.py'), '/kaggle/working/train.py')
    print("Code copied to /kaggle/working")

## 2. Data Preprocessing
We need to:
1.  Resize HR images to 128x128.
2.  Generate LR images (32x32).
3.  Split each dataset into Train/Val (80/20).

In [None]:
def preprocess_image(img_path, save_hr_path, save_lr_path, scale=4, size=(128, 128)):
    """Create one high-/low-resolution image pair from a source image.

    The source is read with OpenCV, resized to ``size`` to form the
    high-resolution (HR) image, and the HR image is then shrunk by ``scale``
    along each axis to form the low-resolution (LR) image. Both are written
    to disk.

    Args:
        img_path: Path of the source image.
        save_hr_path: Directory that receives the HR image, saved as
            ``resized_<basename>``. Created if missing.
        save_lr_path: Directory that receives the LR image, saved as
            ``<basename>``. Created if missing.
        scale: Integer downscaling factor between HR and LR (must be >= 1).
        size: ``(width, height)`` of the HR image, as passed to ``cv2.resize``.

    Returns:
        True when both images were written, False when the source could not
        be read or either write failed.

    Raises:
        ValueError: If ``scale`` is below 1 or would produce an empty LR image.
    """
    if scale < 1:
        raise ValueError(f"scale must be >= 1, got {scale}")

    img = cv2.imread(img_path)
    if img is None:
        # Unreadable or unsupported file: report failure instead of raising so
        # a batch run can keep going.
        return False

    # Resize HR
    hr_img = cv2.resize(img, size, interpolation=cv2.INTER_AREA)

    # Generate LR
    h, w = hr_img.shape[:2]
    lr_size = (w // scale, h // scale)
    if min(lr_size) < 1:
        raise ValueError(
            f"scale={scale} is too large for an HR image of {w}x{h} pixels"
        )
    lr_img = cv2.resize(hr_img, lr_size, interpolation=cv2.INTER_CUBIC)

    # Save
    os.makedirs(save_hr_path, exist_ok=True)
    os.makedirs(save_lr_path, exist_ok=True)
    basename = os.path.basename(img_path)
    # NOTE(review): the HR file carries a 'resized_' prefix while the LR file
    # does not; confirm the dataset loader pairs the two as intended.
    hr_written = cv2.imwrite(os.path.join(save_hr_path, f"resized_{basename}"), hr_img)
    lr_written = cv2.imwrite(os.path.join(save_lr_path, basename), lr_img)

    # imwrite signals failure through its return value, so propagate it rather
    # than claiming success unconditionally.
    return bool(hr_written and lr_written)

def process_dataset(raw_path, out_train_dir, out_val_dir, is_medical=False, split_ratio=0.8, seed=None):
    """Split the images under ``raw_path`` into train/val and preprocess them.

    Every ``.ppm``, ``.png``, ``.jpg`` and ``.tif`` file found recursively
    under ``raw_path`` is shuffled, split by ``split_ratio``, and passed to
    ``preprocess_image``, which writes into the ``hr`` and ``lr``
    sub-directories of the matching output directory.

    Args:
        raw_path: Root directory searched recursively for source images.
        out_train_dir: Output directory for the training split.
        out_val_dir: Output directory for the validation split.
        is_medical: Kept for backward compatibility. Medical and satellite
            data are split the same way, so it does not change the result.
        split_ratio: Fraction of the images assigned to the training split.
        seed: Optional seed for the shuffle. With a seed the split is
            reproducible across runs; with ``None`` the module-level random
            state is used, as before.

    Returns:
        None. When no images are found nothing is created.
    """
    # Find all images
    exts = ('*.ppm', '*.png', '*.jpg', '*.tif')
    files = []
    for ext in exts:
        files.extend(glob(os.path.join(raw_path, '**', ext), recursive=True))

    print(f"Found {len(files)} images in {raw_path}")
    if not files:
        return

    # glob order depends on the filesystem, so sort first; otherwise the same
    # seed could still yield different splits on different machines.
    files.sort()
    if seed is None:
        random.shuffle(files)
    else:
        random.Random(seed).shuffle(files)

    # Determine Split
    split_idx = int(len(files) * split_ratio)
    splits = (
        ("Train", files[:split_idx], out_train_dir),
        ("Val", files[split_idx:], out_val_dir),
    )

    for label, split_files, out_dir in splits:
        print(f"Processing {len(split_files)} {label} images...")
        hr_dir = os.path.join(out_dir, 'hr')
        lr_dir = os.path.join(out_dir, 'lr')
        os.makedirs(hr_dir, exist_ok=True)
        os.makedirs(lr_dir, exist_ok=True)

        failed = 0
        for f in split_files:
            # preprocess_image returns False for images it could not handle.
            if not preprocess_image(f, hr_dir, lr_dir):
                failed += 1
        if failed:
            # Surface skipped images so the split sizes printed above are not
            # mistaken for the number of pairs actually written.
            print(f"WARNING: {failed} {label} images could not be processed")

In [None]:
# Run Processing
# Every dataset is guarded the same way: a missing input is reported and
# skipped instead of silently producing no data.

# Medical (STARE)
if os.path.exists(MEDICAL_RAW_PATH):
    process_dataset(
        MEDICAL_RAW_PATH,
        os.path.join(BASE_DATA_DIR, 'medical', 'train'),
        os.path.join(BASE_DATA_DIR, 'medical', 'val'),
        is_medical=True
    )
else:
    print(f"WARNING: STARE dataset not found at {MEDICAL_RAW_PATH}; skipping.")

# Medical (TB) - written into the same 'medical' folders as STARE
if os.path.exists(TB_RAW_PATH):
    print("Processing TB Dataset...")
    process_dataset(
        TB_RAW_PATH,
        os.path.join(BASE_DATA_DIR, 'medical', 'train'),
        os.path.join(BASE_DATA_DIR, 'medical', 'val'),
        is_medical=True
    )
else:
    print(f"WARNING: TB dataset not found at {TB_RAW_PATH}; skipping.")

# Satellite
# NOTE: the 'satelitte' spelling is kept on purpose; the training command
# reads from the same directory name.
if os.path.exists(SATELLITE_RAW_PATH):
    process_dataset(
        SATELLITE_RAW_PATH,
        os.path.join(BASE_DATA_DIR, 'satelitte', 'train'),
        os.path.join(BASE_DATA_DIR, 'satelitte', 'val'),
        is_medical=False
    )
else:
    print(f"WARNING: Satellite dataset not found at {SATELLITE_RAW_PATH}; skipping.")

## 3. Training

In [None]:
# Launch training as a shell command. IPython substitutes {BASE_DATA_DIR} from
# the notebook namespace before running it, so the path-setup cell must have
# been executed first.
# NOTE: 'satelitte' matches the directory name written by the preprocessing cell.
!python train.py \
    --epochs 50 \
    --batch_size 16 \
    --lr 0.0001 \
    --scale 4 \
    --medical_data {BASE_DATA_DIR}/medical \
    --satellite_data {BASE_DATA_DIR}/satelitte \
    --save_dir checkpoints

In [None]:
# Zip the checkpoints folder for easier download
!zip -r checkpoints.zip checkpoints

from IPython.display import FileLink
import os

# Offer a download link only when the archive exists; otherwise point the
# user at the likely cause.
if os.path.exists('checkpoints.zip'):
    # display() is the notebook built-in supplied by IPython.
    display(FileLink(r'checkpoints.zip'))
else:
    print("Error: checkpoints.zip not found. Training might have failed.")