In [17]:
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from PIL import Image
import os
import pywt
import collections

# --- Configuration ---
CONFIG = {
    "VIDEO_PATH": "video.mp4",
    "IMG_SIZE": 256,
    "MODEL_PATH": "models/specular_removal_model_epoch_50.pth",
    "NUM_CONTEXT_FRAMES": 4,
    "FRAME_SKIP": 3,
    "SPECULAR_THRESHOLD": 0.5,
    "WAVELET": "haar",
    "LEVEL": 2,
    "MIN_MASK_SIZE": 10,
    "MAX_MASK_SIZE": 50,
    "TRAIN_EPOCHS": 50,
    "LEARNING_RATE": 1e-4,
    "INITIAL_KEYFRAME_SECONDS": 5,
    "BATCH_SIZE": 4,
}

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# --- FFC Implementation (Corrected) ---
class FFC(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, ratio_gin, ratio_gout, stride=1, padding=0):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.ratio_gin, self.ratio_gout = ratio_gin, ratio_gout
        
        self.in_local = int(self.in_channels * (1 - ratio_gin))
        self.in_global = self.in_channels - self.in_local
        self.out_local = int(self.out_channels * (1 - ratio_gout))
        self.out_global = self.out_channels - self.out_local

        if self.in_local > 0 and self.out_local > 0:
            self.conv_l = nn.Conv2d(self.in_local, self.out_local, kernel_size, stride, padding, bias=False)
        else:
            self.conv_l = None
        
        if self.in_global > 0 and self.out_global > 0:
            self.conv_g = nn.Conv2d(self.in_global * 2, self.out_global * 2, kernel_size=1, bias=False)
        else:
            self.conv_g = None

    def forward(self, x):
        batch_size, _, height, width = x.shape
        if self.in_local > 0 and self.in_global > 0:
            x_l, x_g = torch.split(x, [self.in_local, self.in_global], dim=1)
        elif self.in_local > 0:
            x_l, x_g = x, None
        else:
            x_l, x_g = None, x

        out_l = None
        if self.conv_l is not None:
            if x_l is not None:
                out_l = self.conv_l(x_l)
            elif self.out_local > 0:
                out_l = torch.zeros(batch_size, self.out_local, height, width, device=x.device)

        out_g = None
        if x_g is not None and self.conv_g is not None:
            fft_g = torch.fft.rfft2(x_g, norm='ortho')
            fft_g_real_imag = torch.cat([fft_g.real, fft_g.imag], dim=1)
            ffc_out_real_imag = self.conv_g(fft_g_real_imag)
            ffc_out_real, ffc_out_imag = torch.split(ffc_out_real_imag, self.out_global, dim=1)
            fft_g = torch.complex(ffc_out_real, ffc_out_imag)
            out_g = torch.fft.irfft2(fft_g, s=(height, width), norm='ortho')
        elif self.out_global > 0:
            out_g = torch.zeros(batch_size, self.out_global, height, width, device=x.device)

        if out_l is not None and out_g is not None:
            return torch.cat([out_l, out_g], dim=1)
        elif out_l is not None:
            return out_l
        elif out_g is not None:
            return out_g
        else:
            return torch.zeros(batch_size, self.out_channels, height, width, device=x.device)

class FFC_BN_ACT(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 ratio_gin=0.75, ratio_gout=0.75, stride=1, padding=0):
        super().__init__()
        self.ffc = FFC(in_channels, out_channels, kernel_size, ratio_gin, ratio_gout, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.ffc(x)
        x = self.bn(x)
        x = self.act(x)
        return x

# --- Model Definition (Corrected Architecture) ---

class LaMaUNet(nn.Module):
    def __init__(self, n_channels=18, n_classes=3, bilinear=True):
        super().__init__()
        ratio = 0.75

        # --- Encoder ---
        self.inc = nn.Sequential(
            FFC_BN_ACT(n_channels, 32, kernel_size=3, padding=1, ratio_gin=0, ratio_gout=ratio),
            FFC_BN_ACT(32, 32, kernel_size=3, padding=1, ratio_gin=ratio, ratio_gout=ratio)
        )
        self.down1 = nn.Sequential(
            nn.MaxPool2d(2),
            FFC_BN_ACT(32, 64, kernel_size=3, padding=1, ratio_gin=ratio, ratio_gout=ratio)
        )
        
        # --- Bottleneck ---
        self.bottleneck = nn.Sequential(
            FFC_BN_ACT(64, 128, kernel_size=3, padding=1, ratio_gin=ratio, ratio_gout=ratio),
            FFC_BN_ACT(128, 64, kernel_size=3, padding=1, ratio_gin=ratio, ratio_gout=ratio)
        )
        
        # --- Decoder ---
        self.up1 = nn.Upsample(scale_factor=2, mode='bilinear' if bilinear else 'nearest', align_corners=True if bilinear else None)
        # Input channels: 64 from upsample (from bottleneck) + 32 from skip connection (x1) = 96
        self.conv_up1 = nn.Sequential(
            FFC_BN_ACT(96, 32, kernel_size=3, padding=1, ratio_gin=ratio, ratio_gout=ratio),
            FFC_BN_ACT(32, 32, kernel_size=3, padding=1, ratio_gin=ratio, ratio_gout=ratio)
        )

        # --- Final Output Layer ---
        self.outc = nn.Conv2d(32, n_classes, kernel_size=1)

    def forward(self, x):
        # --- Encoder Path ---
        x1 = self.inc(x)           # Shape: [B, 32, 256, 256]
        x2 = self.down1(x1)         # Shape: [B, 64, 128, 128]
        
        # --- Bottleneck ---
        x_bottle = self.bottleneck(x2) # Shape: [B, 64, 128, 128]

        # --- Decoder Path ---
        u1 = self.up1(x_bottle)     # Shape: [B, 64, 256, 256]
        
        # Skip Connection: Concatenate with x1 from the same level
        c1 = torch.cat([u1, x1], dim=1) # Shape: [B, 64+32=96, 256, 256]
        
        # Up-convolution
        d1 = self.conv_up1(c1)      # Shape: [B, 32, 256, 256]

        # Final output
        logits = self.outc(d1)      # Shape: [B, 3, 256, 256]

        return torch.sigmoid(logits)

# ... The rest of the code (VideoFrameProcessor, main, etc.) remains unchanged ...
def calculate_specular_score(frame):
    """Calculates a specular score based on the area and intensity of specular regions."""
    hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
    v = hsv[:, :, 2]
    specular_mask = (v > 200) & (hsv[:, :, 1] < 40)
    if np.sum(specular_mask) == 0:
        return 0.0
    specular_area = np.sum(specular_mask) / (frame.shape[0] * frame.shape[1])
    specular_intensity = np.mean(v[specular_mask]) / 255.0
    return specular_area * specular_intensity

def get_sharpness(image):
    """Calculates image sharpness using the variance of the Laplacian."""
    if len(image.shape) == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    return cv2.Laplacian(image, cv2.CV_64F).var()

class VideoFrameProcessor:
    def __init__(self, config):
        self.config = config
        self.transform = transforms.Compose([transforms.ToTensor()])
        self.model = LaMaUNet(n_channels=18, n_classes=3).to(DEVICE)
        if os.path.exists(config["MODEL_PATH"]):
            self.model.load_state_dict(torch.load(config["MODEL_PATH"], map_location=DEVICE))
            print(f"Loaded model from {config['MODEL_PATH']}")
        else:
            print(f"Model file {config['MODEL_PATH']} not found. Training a new model...")
            self._train_model()
        self.model.eval()

    def _train_model(self):
        from torch.utils.data import Dataset, DataLoader
        class TempDataset(Dataset):
            def __init__(self, frames, keyframe, config):
                self.frames = frames
                self.keyframe = keyframe
                self.config = config
                self.transform = transforms.Compose([transforms.ToTensor()])
            def __len__(self): return len(self.frames)
            def __getitem__(self, idx):
                frame = cv2.resize(self.frames[idx], (self.config["IMG_SIZE"], self.config["IMG_SIZE"]))
                target_tensor = self.transform(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
                
                keyframe_tensor = self.transform(Image.fromarray(cv2.cvtColor(cv2.resize(self.keyframe, (self.config["IMG_SIZE"], self.config["IMG_SIZE"])), cv2.COLOR_BGR2RGB))) * 0.5
                
                context_tensors = []
                for i in range(self.config["NUM_CONTEXT_FRAMES"]):
                    context_idx = max(0, idx - ((i + 1) * self.config["FRAME_SKIP"]))
                    ctx_frame = cv2.resize(self.frames[context_idx], (self.config["IMG_SIZE"], self.config["IMG_SIZE"]))
                    ctx_tensor = self.transform(Image.fromarray(cv2.cvtColor(ctx_frame, cv2.COLOR_BGR2RGB)))
                    context_tensors.append(ctx_tensor)
                
                masked_target_tensor = target_tensor.clone()
                if np.random.random() > 0.5:
                    mask_h, mask_w = np.random.randint(10, 30), np.random.randint(10, 30)
                    start_h, start_w = np.random.randint(0, self.config["IMG_SIZE"] - mask_h), np.random.randint(0, self.config["IMG_SIZE"] - mask_w)
                    masked_target_tensor[:, start_h:start_h+mask_h, start_w:start_w+mask_w] = 0
                
                model_input = torch.cat(context_tensors[::-1] + [masked_target_tensor, keyframe_tensor], dim=0)
                ground_truth = target_tensor
                return model_input, ground_truth

        cap = cv2.VideoCapture(self.config["VIDEO_PATH"])
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret: break
            frames.append(frame)
        cap.release()

        best_score = float('inf')
        best_keyframe = frames[0]
        for frame in frames[:int(self.config["INITIAL_KEYFRAME_SECONDS"] * 30)]:
            score = calculate_specular_score(frame)
            if score < best_score:
                best_score = score
                best_keyframe = frame
        
        dataset = TempDataset(frames, best_keyframe, self.config)
        loader = DataLoader(dataset, batch_size=self.config["BATCH_SIZE"], shuffle=True)
        criterion = nn.MSELoss()
        optimizer = optim.Adam(self.model.parameters(), lr=self.config["LEARNING_RATE"])
        
        self.model.train()
        print("Starting fallback training...")
        for epoch in range(self.config["TRAIN_EPOCHS"]):
            epoch_loss = 0.0
            for model_input, ground_truth in loader:
                model_input, ground_truth = model_input.to(DEVICE), ground_truth.to(DEVICE)
                optimizer.zero_grad()
                outputs = self.model(model_input)
                loss = criterion(outputs, ground_truth)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
            print(f"Training epoch {epoch+1}/{self.config['TRAIN_EPOCHS']} completed. Average Loss: {epoch_loss / len(loader):.6f}")
            
        os.makedirs("models", exist_ok=True)
        torch.save(self.model.state_dict(), self.config["MODEL_PATH"])
        print(f"Trained and saved model to {self.config['MODEL_PATH']}")
        self.model.eval()

    def _detect_specular_regions(self, frame):
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        coeffs = pywt.wavedec2(gray, self.config["WAVELET"], level=self.config["LEVEL"])
        details = []
        for detail_coeffs in coeffs[1:]:
            details.extend([np.abs(detail_coeffs[i]) for i in range(3)])
        
        detail_features = np.stack([cv2.resize(d, (gray.shape[1], gray.shape[0])) for d in details])
        detail_features = np.max(detail_features, axis=0)
        
        normalized_features = (detail_features - np.min(detail_features)) / (np.max(detail_features) - np.min(detail_features) + 1e-6)
        wavelet_mask = normalized_features > self.config["SPECULAR_THRESHOLD"]
        
        hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
        v = hsv[:, :, 2]
        high_v = v > 200
        low_s = hsv[:,:,1] < 40
        combined_mask = np.logical_or(np.logical_and(high_v, low_s), wavelet_mask)
        
        kernel = np.ones((3, 3), np.uint8)
        mask = cv2.morphologyEx(combined_mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
        return mask

    def _mask_with_black_boxes(self, frame, mask):
        h, w = frame.shape[:2]
        black_box_mask = np.zeros((h, w), dtype=np.uint8)
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        for contour in contours:
            if cv2.contourArea(contour) > 10:
                x, y, w_box, h_box = cv2.boundingRect(contour)
                box_size = np.random.randint(self.config["MIN_MASK_SIZE"], self.config["MAX_MASK_SIZE"])
                x_start = max(0, x + w_box // 2 - box_size // 2)
                y_start = max(0, y + h_box // 2 - box_size // 2)
                x_end = min(w, x_start + box_size)
                y_end = min(h, y_start + box_size)
                cv2.rectangle(black_box_mask, (x_start, y_start), (x_end, y_end), 1, -1)
        return black_box_mask

    def process_frame(self, frame, frame_idx_in_buffer, frame_buffer, keyframe):
        orig_h, orig_w = frame.shape[:2]
        frame_resized = cv2.resize(frame, (self.config["IMG_SIZE"], self.config["IMG_SIZE"]))
        keyframe_resized = cv2.resize(keyframe, (self.config["IMG_SIZE"], self.config["IMG_SIZE"]))

        frame_tensor = self.transform(Image.fromarray(cv2.cvtColor(frame_resized, cv2.COLOR_BGR2RGB))).unsqueeze(0).to(DEVICE)
        keyframe_tensor = self.transform(Image.fromarray(cv2.cvtColor(keyframe_resized, cv2.COLOR_BGR2RGB))).unsqueeze(0).to(DEVICE) * 0.5

        specular_mask = self._detect_specular_regions(frame_resized)
        black_box_mask = self._mask_with_black_boxes(frame_resized, specular_mask)
        mask_tensor = torch.from_numpy(black_box_mask).float().unsqueeze(0).unsqueeze(0).to(DEVICE)
        masked_target_tensor = frame_tensor * (1 - mask_tensor)

        context_tensors = []
        for i in range(self.config["NUM_CONTEXT_FRAMES"]):
            context_idx = max(0, frame_idx_in_buffer - ((i + 1) * self.config["FRAME_SKIP"]))
            ctx_frame = cv2.resize(frame_buffer[context_idx], (self.config["IMG_SIZE"], self.config["IMG_SIZE"]))
            ctx_tensor = self.transform(Image.fromarray(cv2.cvtColor(ctx_frame, cv2.COLOR_BGR2RGB))).unsqueeze(0).to(DEVICE)
            context_tensors.append(ctx_tensor)
        
        input_tensor = torch.cat(context_tensors[::-1] + [masked_target_tensor, keyframe_tensor], dim=1)

        with torch.no_grad():
            output_tensor = self.model(input_tensor)

        output_img_resized = (output_tensor[0].cpu().permute(1, 2, 0).numpy() * 255).clip(0, 255).astype(np.uint8)
        output_img_resized = cv2.cvtColor(output_img_resized, cv2.COLOR_RGB2BGR)
        
        output_score = calculate_specular_score(output_img_resized)
        keyframe_score = calculate_specular_score(keyframe_resized)
        
        if output_score < keyframe_score * 0.95:
            output_sharpness = get_sharpness(output_img_resized)
            keyframe_sharpness = get_sharpness(keyframe_resized)
            
            if output_sharpness >= keyframe_sharpness * 0.8:
                keyframe = cv2.resize(output_img_resized, (orig_w, orig_h))
                print(f"  -> Keyframe updated! New score: {output_score:.4f}, Sharpness: {output_sharpness:.2f}")
            else:
                print(f"  -> Rejected keyframe update due to low sharpness ({output_sharpness:.2f} vs {keyframe_sharpness:.2f})")
        
        final_output_img = cv2.resize(output_img_resized, (orig_w, orig_h))
        return final_output_img, keyframe, black_box_mask

def main():
    processor = VideoFrameProcessor(CONFIG)

    cap = cv2.VideoCapture(CONFIG["VIDEO_PATH"])
    if not cap.isOpened():
        print(f"Error: Could not open video file at {CONFIG['VIDEO_PATH']}")
        return
        
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0

    print("Scanning initial frames for best keyframe...")
    initial_frames = []
    num_initial_frames = int(CONFIG["INITIAL_KEYFRAME_SECONDS"] * fps)
    for _ in range(num_initial_frames):
        ret, frame = cap.read()
        if not ret: break
        initial_frames.append(frame)

    if not initial_frames:
        print("Error: Could not read any frames from video.")
        cap.release()
        return

    best_keyframe = initial_frames[0].copy()
    best_score = float('inf')
    for frame in initial_frames:
        score = calculate_specular_score(frame)
        if score < best_score:
            best_score = score
            best_keyframe = frame.copy()
    print(f"Initial keyframe found with specular score: {best_score:.4f}")
    
    max_buffer_size = CONFIG["NUM_CONTEXT_FRAMES"] * CONFIG["FRAME_SKIP"] + 5
    frame_buffer = collections.deque(initial_frames, maxlen=max_buffer_size)
    
    os.makedirs("inpainted_output", exist_ok=True)
    keyframe = best_keyframe
    
    # Process the frames already in the buffer first
    for i, frame in enumerate(initial_frames):
        output_img, keyframe, _ = processor.process_frame(frame, i, frame_buffer, keyframe)
        output_path = f"inpainted_output/frame_{i:06d}.png"
        
        # ✅✅✅ CORRECTED LINE ✅✅✅
        cv2.imwrite(output_path, output_img)
        
        print(f"Processed initial frame {i+1}/{len(initial_frames)}")
        
    # Process the rest of the video
    frame_idx = len(initial_frames)
    while True:
        ret, current_frame = cap.read()
        if not ret:
            break
        
        frame_buffer.append(current_frame)
        
        output_img, keyframe, _ = processor.process_frame(current_frame, len(frame_buffer) - 1, frame_buffer, keyframe)
        
        output_path = f"inpainted_output/frame_{frame_idx:06d}.png"
        cv2.imwrite(output_path, output_img)
        print(f"Processed frame {frame_idx}...")
        frame_idx += 1
        
    cap.release()
    print("Processing complete. Output saved to 'inpainted_output' directory.")

if __name__ == "__main__":
    main()

Using device: cuda
Model file models/specular_removal_model_epoch_50.pth not found. Training a new model...
Starting fallback training...
Training epoch 1/50 completed. Average Loss: 0.153508
Training epoch 2/50 completed. Average Loss: 0.072186
Training epoch 3/50 completed. Average Loss: 0.034875
Training epoch 4/50 completed. Average Loss: 0.017543
Training epoch 5/50 completed. Average Loss: 0.009831
Training epoch 6/50 completed. Average Loss: 0.006028
Training epoch 7/50 completed. Average Loss: 0.003942
Training epoch 8/50 completed. Average Loss: 0.002702
Training epoch 9/50 completed. Average Loss: 0.001942
Training epoch 10/50 completed. Average Loss: 0.001452
Training epoch 11/50 completed. Average Loss: 0.001126
Training epoch 12/50 completed. Average Loss: 0.000904
Training epoch 13/50 completed. Average Loss: 0.000741
Training epoch 14/50 completed. Average Loss: 0.000626
Training epoch 15/50 completed. Average Loss: 0.000535
Training epoch 16/50 completed. Average Loss: 

IndexError: deque index out of range