In [2]:
import torch
import torch._dynamo
# Suppress errors from torch.compile to fall back to eager mode if needed
torch._dynamo.config.suppress_errors = True

import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import cv2
import os
from PIL import Image, ImageTk
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score
import tkinter as tk
from tkinter import filedialog, ttk, messagebox
import threading
from pycocotools.coco import COCO
import urllib.request
import zipfile
import sys

def remove_orig_mod_prefix(state_dict):
    """
    Remove the '_orig_mod.' prefix from keys in a state dict.
    If both a prefixed and unprefixed key exist for the same parameter,
    the cleaned (unprefixed) version from the prefixed key is used.
    """
    cleaned = {}
    # First, add all keys that have the prefix (cleaned)
    for key, value in state_dict.items():
        if key.startswith("_orig_mod."):
            cleaned[key[len("_orig_mod."):]] = value
    # Then, add keys that do not have the prefix only if they don't conflict.
    for key, value in state_dict.items():
        if not key.startswith("_orig_mod."):
            if key not in cleaned:
                cleaned[key] = value
    return cleaned

def download_dataset():
    """Downloads the COCO dataset for semantic segmentation and colorization."""
    if not os.path.exists("datasets"):
        os.makedirs("datasets")
    if not os.path.exists("datasets/val2017"):
        print("Downloading COCO dataset...")
        url = "http://images.cocodataset.org/zips/val2017.zip"
        urllib.request.urlretrieve(url, "val2017.zip")
        with zipfile.ZipFile("val2017.zip", 'r') as zip_ref:
            zip_ref.extractall("datasets")
        os.remove("val2017.zip")
        print("Dataset downloaded and extracted successfully!")
    else:
        print("Dataset already exists!")

class ColorizationDataset(Dataset):
    def __init__(self, image_dir, transform=None, size=(256, 256)):
        self.image_dir = image_dir
        # Reduce dataset size for faster training
        self.image_files = [f for f in os.listdir(image_dir) if f.endswith('.jpg')][:500]
        self.transform = transform
        self.size = size
        
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        image = Image.open(img_path).convert('RGB')
        image = image.resize(self.size, Image.Resampling.LANCZOS)
        image_np = np.array(image)
        
        # Convert to LAB color space
        lab_image = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB)
        
        # Normalize L channel to [-1, 1]
        l_channel = lab_image[:, :, 0].astype(np.float32) / 50.0 - 1.0
        
        # Normalize ab channels to [-1, 1]
        ab_channels = lab_image[:, :, 1:].astype(np.float32)
        ab_channels = (ab_channels - 128.0) / 128.0
        
        # Create a binary mask using edge detection and thresholding
        gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
        edges = cv2.Canny(gray, 100, 200)
        kernel = np.ones((5,5), np.uint8)
        dilated = cv2.dilate(edges, kernel, iterations=2)
        mask = dilated > 0
        
        if self.transform:
            # For L: convert normalized values back to [0,255] for PIL conversion, then re-normalize
            l_img = Image.fromarray(((l_channel + 1.0) * 50.0).astype(np.uint8))
            l_tensor = self.transform(l_img)
            l_tensor = l_tensor * 2.0 - 1.0
            
            # For ab channels: convert with shape (H,W,2) to tensor
            ab_tensor = torch.from_numpy(ab_channels.transpose((2, 0, 1))).float()
            mask_tensor = torch.from_numpy(mask.astype(np.float32))
            
            return l_tensor, ab_tensor, mask_tensor
        else:
            return (torch.from_numpy(l_channel).unsqueeze(0),
                    torch.from_numpy(ab_channels.transpose((2, 0, 1))),
                    torch.from_numpy(mask.astype(np.float32)))

class SegColorizer(nn.Module):
    def __init__(self):
        super(SegColorizer, self).__init__()
        
        # Improved encoder with residual connections
        self.enc1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        
        self.enc2 = nn.Sequential(
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        
        self.enc3 = nn.Sequential(
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )
        
        # Segmentation branch
        self.seg_branch = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, kernel_size=1),
            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        )
        
        # Colorization branch with improved upsampling
        self.dec3 = nn.Sequential(
            nn.Conv2d(257, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        )
        
        self.dec2 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        )
        
        self.dec1 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 2, kernel_size=1),
            nn.Tanh()
        )
        
    def forward(self, x, mask=None):
        # Encoding
        enc1 = self.enc1(x)
        enc2 = self.enc2(enc1)
        enc3 = self.enc3(enc2)
        
        # Segmentation
        seg_output = self.seg_branch(enc3)
        
        if mask is not None:
            seg_mask = mask.unsqueeze(1)
        else:
            seg_mask = torch.sigmoid(seg_output)
        
        # Combine features with segmentation mask
        seg_small = nn.functional.interpolate(seg_mask, size=enc3.shape[2:], mode='bilinear', align_corners=True)
        combined = torch.cat([enc3, seg_small], dim=1)
        
        # Decoding with skip connections
        dec3 = self.dec3(combined)
        dec2 = self.dec2(dec3)
        color_output = self.dec1(dec2)
        
        return color_output, seg_output

def train_model(model, train_loader, val_loader, num_epochs=10):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if device.type == 'cpu':
        print("WARNING: GPU not detected. Training will be slow!")
    else:
        print(f"Using GPU: {torch.cuda.get_device_name()}")
        
    model = model.to(device)
    # For Windows systems, disable multiprocessing in DataLoader
    num_workers = 0 if sys.platform.startswith('win') else 2

    # Use torch.compile for potential speedup; falls back to eager mode if needed.
    model = torch.compile(model)
    
    criterion_color = nn.MSELoss().to(device)
    criterion_seg = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([2.0]).to(device)).to(device)
    
    optimizer = optim.AdamW(model.parameters(), lr=0.002, weight_decay=0.01, betas=(0.9, 0.999))
    
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)
    
    scaler = torch.amp.GradScaler(device='cuda')
    
    best_loss = float('inf')
    metrics = {'train_loss': [], 'val_loss': [], 'precision': [], 'recall': []}
    
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        
        for batch_idx, (l_channel, ab_channels, mask) in enumerate(pbar):
            l_channel = l_channel.to(device, non_blocking=True)
            ab_channels = ab_channels.to(device, non_blocking=True)
            mask = mask.to(device, non_blocking=True)
            
            optimizer.zero_grad(set_to_none=True)
            
            with torch.amp.autocast(device_type='cuda'):
                color_output, seg_output = model(l_channel, mask)
                loss_color = criterion_color(color_output, ab_channels)
                loss_seg = criterion_seg(seg_output, mask.unsqueeze(1).float())
                loss = 0.7 * loss_color + 0.3 * loss_seg
            
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            
            train_loss += loss.item()
            
            if batch_idx % 5 == 0:
                pbar.set_postfix({'loss': f'{loss.item():.4f}', 'lr': f'{scheduler.get_last_lr()[0]:.6f}'})
        
        scheduler.step()
        
        model.eval()
        val_loss = 0
        all_preds = []
        all_masks = []
        
        with torch.no_grad(), torch.amp.autocast(device_type='cuda'):
            for l_channel, ab_channels, mask in val_loader:
                l_channel = l_channel.to(device, non_blocking=True)
                ab_channels = ab_channels.to(device, non_blocking=True)
                mask = mask.to(device, non_blocking=True)
                
                color_output, seg_output = model(l_channel)
                loss_color = criterion_color(color_output, ab_channels)
                loss_seg = criterion_seg(seg_output, mask.unsqueeze(1).float())
                val_loss += (0.7 * loss_color + 0.3 * loss_seg).item()
                
                pred_masks = (torch.sigmoid(seg_output) > 0.5).float()
                all_preds.extend(pred_masks.cpu().numpy().flatten())
                all_masks.extend((mask.cpu().numpy() > 0.5).flatten())
        
        val_loss /= len(val_loader)
        precision = precision_score(all_masks, all_preds, zero_division=1)
        recall = recall_score(all_masks, all_preds, zero_division=1)
        
        metrics['train_loss'].append(train_loss / len(train_loader))
        metrics['val_loss'].append(val_loss)
        metrics['precision'].append(precision)
        metrics['recall'].append(recall)
        
        print(f'\nEpoch {epoch+1}/{num_epochs}:')
        print(f'Train Loss: {train_loss/len(train_loader):.4f}')
        print(f'Val Loss: {val_loss:.4f}')
        print(f'Precision: {precision:.4f}')
        print(f'Recall: {recall:.4f}\n')
        
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_loss,
            }, 'best_model.pth')
            print(f'Saved new best model with loss: {best_loss:.4f}')
    
    return metrics

def enhance_saturation(image, factor=1.5):
    """
    Enhance the saturation of a given RGB image by the specified factor.
    """
    # Convert image to HSV
    hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
    # Increase the saturation channel
    hsv[:, :, 1] = np.clip(hsv[:, :, 1].astype(np.float32) * factor, 0, 255).astype(np.uint8)
    # Convert back to RGB
    enhanced = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
    return enhanced

class InteractiveColorizationGUI:
    def __init__(self, model_path='best_model.pth'):
        self.window = tk.Tk()
        self.window.title("Interactive Image Colorization")
        self.window.geometry("1200x800")
        
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = SegColorizer().to(self.device)
        if os.path.exists(model_path):
            checkpoint = torch.load(model_path, map_location=self.device)
            state_dict = remove_orig_mod_prefix(checkpoint['model_state_dict'])
            self.model.load_state_dict(state_dict)
            self.model.eval()
            
        self.setup_gui()
        
    def setup_gui(self):
        control_panel = ttk.Frame(self.window, padding="10")
        control_panel.grid(row=0, column=0, sticky="nsew")
        
        ttk.Button(control_panel, text="Load Image", command=self.load_image).grid(row=0, column=0, pady=5)
        ttk.Label(control_panel, text="Select Regions to Colorize:").grid(row=1, column=0, pady=5)
        
        self.region_vars = {
            'Foreground': tk.BooleanVar(value=True),
            'Background': tk.BooleanVar(value=True)
        }
        
        row = 2
        for region, var in self.region_vars.items():
            ttk.Checkbutton(control_panel, text=region, variable=var, command=self.update_preview).grid(row=row, column=0)
            row += 1
            
        ttk.Button(control_panel, text="Colorize", command=self.colorize_image).grid(row=row, column=0, pady=10)
        
        self.canvas = tk.Canvas(self.window, width=800, height=600)
        self.canvas.grid(row=0, column=1, padx=10, pady=10)
        
        self.progress = ttk.Progressbar(self.window, orient="horizontal", length=200, mode="determinate")
        self.progress.grid(row=1, column=1, padx=10, pady=5)
        
        self.window.grid_columnconfigure(1, weight=1)
        self.window.grid_rowconfigure(0, weight=1)
        
        self.original_image = None
        self.processed_image = None
        self.mask = None
        
    def load_image(self):
        try:
            file_path = filedialog.askopenfilename(filetypes=[("Image files", "*.jpg *.jpeg *.png *.bmp")])
            if file_path:
                image = Image.open(file_path).convert('RGB')
                image = image.resize((256, 256), Image.Resampling.LANCZOS)
                self.original_image = image
                self.display_image(image)
                self.generate_segmentation()
        except Exception as e:
            messagebox.showerror("Error", f"Error loading image: {str(e)}")
            
    def generate_segmentation(self):
        if self.original_image is None:
            return
            
        try:
            image_np = np.array(self.original_image)
            lab_image = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB)
            # Mimic conversion for L channel
            l_channel = lab_image[:, :, 0].astype(np.float32) / 50.0 - 1.0
            l_tensor = torch.from_numpy(l_channel).unsqueeze(0).unsqueeze(0).to(self.device)
            with torch.no_grad(), torch.amp.autocast(device_type='cuda'):
                _, seg_output = self.model(l_tensor)
                self.mask = torch.sigmoid(seg_output).cpu().numpy()[0, 0]
            self.update_preview()
        except Exception as e:
            messagebox.showerror("Error", f"Error generating segmentation: {str(e)}")
            
    def update_preview(self):
        if self.mask is None:
            return
        try:
            # Use continuous threshold for preview
            preview_mask = (self.mask > 0.5)
            preview = np.array(self.original_image)
            gray = cv2.cvtColor(preview, cv2.COLOR_RGB2GRAY)
            gray_rgb = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
            preview[~preview_mask] = gray_rgb[~preview_mask]
            self.display_image(Image.fromarray(preview))
        except Exception as e:
            messagebox.showerror("Error", f"Error updating preview: {str(e)}")
            
    def colorize_image(self):
        if self.original_image is None:
            return
        try:
            image_np = np.array(self.original_image)
            lab_image = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB)
            l_channel = lab_image[:, :, 0].astype(np.float32) / 50.0 - 1.0
            l_tensor = torch.from_numpy(l_channel).unsqueeze(0).unsqueeze(0).to(self.device)
            
            # Get continuous mask probabilities for blending
            with torch.no_grad(), torch.amp.autocast(device_type='cuda'):
                _, seg_output = self.model(l_tensor)
                mask_prob = torch.sigmoid(seg_output).cpu().numpy()[0, 0]
            
            # Determine blending weight based on user options
            if self.region_vars['Foreground'].get() and not self.region_vars['Background'].get():
                blend = mask_prob
            elif self.region_vars['Background'].get() and not self.region_vars['Foreground'].get():
                blend = 1 - mask_prob
            elif self.region_vars['Foreground'].get() and self.region_vars['Background'].get():
                blend = np.ones_like(mask_prob)
            else:
                blend = np.zeros_like(mask_prob)
            
            # Use binary mask (threshold 0.5) for ab channel generation
            binary_mask = (mask_prob > 0.5).astype(np.float32)
            mask_tensor = torch.from_numpy(binary_mask).unsqueeze(0).to(self.device)
            with torch.no_grad(), torch.amp.autocast(device_type='cuda'):
                color_output, _ = self.model(l_tensor, mask_tensor)
                ab_channels = color_output.cpu().numpy()[0]
                ab_channels = ab_channels.transpose(1, 2, 0)
                ab_channels = ab_channels * 128.0 + 128.0
            colorized_lab = np.concatenate([lab_image[:, :, 0:1], ab_channels], axis=2)
            colorized_rgb = cv2.cvtColor(colorized_lab.astype(np.uint8), cv2.COLOR_LAB2RGB)
            gray_img = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
            gray_rgb = cv2.cvtColor(gray_img, cv2.COLOR_GRAY2RGB)
            blend = blend[..., np.newaxis]
            final_result = blend * colorized_rgb + (1 - blend) * gray_rgb

            # --- Enhancement Step: Increase Saturation ---
            final_result = final_result.astype(np.uint8)
            final_result = enhance_saturation(final_result, factor=1.5)
            # -------------------------------------------------

            self.display_image(Image.fromarray(final_result))
            
            if messagebox.askyesno("Save", "Would you like to save the colorized image?"):
                save_path = filedialog.asksaveasfilename(defaultextension=".png",
                                                        filetypes=[("PNG files", "*.png"),
                                                                  ("JPEG files", "*.jpg"),
                                                                  ("All files", "*.*")])
                if save_path:
                    Image.fromarray(final_result).save(save_path)
                    
        except Exception as e:
            messagebox.showerror("Error", f"Error during colorization: {str(e)}")
            
    def display_image(self, image):
        display_size = (800, 600)
        image.thumbnail(display_size, Image.Resampling.LANCZOS)
        photo = ImageTk.PhotoImage(image)
        self.canvas.delete("all")
        self.canvas.create_image(400, 300, image=photo)
        self.canvas.image = photo  # Keep reference
        
    def run(self):
        self.window.mainloop()

class ColorizationSystem:
    def __init__(self, model_path='best_model.pth', image_size=(256, 256)):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = SegColorizer().to(self.device)
        self.image_size = image_size
        if os.path.exists(model_path):
            checkpoint = torch.load(model_path, map_location=self.device)
            state_dict = remove_orig_mod_prefix(checkpoint['model_state_dict'])
            self.model.load_state_dict(state_dict)
            self.model.eval()
        else:
            raise FileNotFoundError(f"Model file {model_path} not found!")
    
    def process_image(self, image_path, output_path, colorize_foreground=True, colorize_background=True):
        image = Image.open(image_path).convert('RGB')
        image = image.resize(self.image_size, Image.Resampling.LANCZOS)
        image_np = np.array(image)
        lab_image = cv2.cvtColor(image_np, cv2.COLOR_RGB2LAB)
        l_channel = lab_image[:, :, 0].astype(np.float32) / 50.0 - 1.0
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(self.image_size, antialias=True)
        ])
        l_tensor = transform(Image.fromarray(((l_channel + 1.0) * 50.0).astype(np.uint8))).unsqueeze(0)
        with torch.no_grad(), torch.amp.autocast(device_type='cuda'):
            l_tensor = l_tensor.to(self.device)
            _, seg_output = self.model(l_tensor)
            mask_prob = torch.sigmoid(seg_output).cpu().numpy()[0, 0]
        if colorize_foreground and not colorize_background:
            blend = mask_prob
        elif colorize_background and not colorize_foreground:
            blend = 1 - mask_prob
        elif colorize_foreground and colorize_background:
            blend = np.ones_like(mask_prob)
        else:
            blend = np.zeros_like(mask_prob)
        binary_mask = (mask_prob > 0.5).astype(np.float32)
        mask_tensor = torch.from_numpy(binary_mask).unsqueeze(0).to(self.device)
        with torch.no_grad(), torch.amp.autocast(device_type='cuda'):
            mask_tensor = mask_tensor.to(self.device)
            color_output, _ = self.model(l_tensor, mask_tensor)
            ab_channels = color_output.cpu().numpy()[0]
            ab_channels = ab_channels.transpose(1, 2, 0)
            ab_channels = ab_channels * 128.0 + 128.0
        colorized_lab = np.concatenate([lab_image[:, :, 0:1], ab_channels], axis=2)
        colorized_rgb = cv2.cvtColor(colorized_lab.astype(np.uint8), cv2.COLOR_LAB2RGB)
        gray_img = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
        gray_rgb = cv2.cvtColor(gray_img, cv2.COLOR_GRAY2RGB)
        blend = blend[..., np.newaxis]
        final_result = blend * colorized_rgb + (1 - blend) * gray_rgb

        # --- Enhancement Step: Increase Saturation ---
        final_result = final_result.astype(np.uint8)
        final_result = enhance_saturation(final_result, factor=1.5)
        # -------------------------------------------------

        Image.fromarray(final_result).save(output_path)
        return output_path

def cli_interface():
    """Command-line Image Colorization Interface"""
    print("Command-line Image Colorization Interface")
    while True:
        image_path = input("Enter path to input image: ").strip().replace('\"', '')
        if os.path.exists(image_path):
            break
        print("File not found. Please try again.")
    output_path = input("Enter path for output image (press Enter for auto-generated): ").strip()
    if not output_path:
        output_dir = "output"
        os.makedirs(output_dir, exist_ok=True)
        base_name = os.path.splitext(os.path.basename(image_path))[0]
        output_path = os.path.join(output_dir, f"{base_name}_colorized.png")
    colorize_foreground = input("Colorize foreground? (y/n): ").lower().startswith('y')
    colorize_background = input("Colorize background? (y/n): ").lower().startswith('y')
    try:
        system = ColorizationSystem(image_size=(256, 256))
        result_path = system.process_image(image_path, output_path, colorize_foreground, colorize_background)
        print(f"Colorized image saved to: {result_path}")
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        import traceback
        traceback.print_exc()

def check_display():
    """Check if a display server is available (for GUI)."""
    try:
        tk.Tk().destroy()
        return True
    except:
        return False

def main():
    if not os.path.exists('best_model.pth'):
        print("No pre-trained model found. Starting training...")
        download_dataset()
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        train_dataset = ColorizationDataset("datasets/val2017", transform=transform)
        val_dataset = ColorizationDataset("datasets/val2017", transform=transform)
        
        # For Windows, set num_workers to 0 to avoid multiprocessing issues.
        num_workers = 0 if sys.platform.startswith('win') else 2
        
        train_loader = DataLoader(
            train_dataset, 
            batch_size=16, 
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True
        )
        val_loader = DataLoader(
            val_dataset, 
            batch_size=32, 
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True
        )
        model = SegColorizer()
        metrics = train_model(model, train_loader, val_loader)
        plt.figure(figsize=(12, 4))
        plt.subplot(1, 2, 1)
        plt.plot(metrics['train_loss'], label='Train Loss')
        plt.plot(metrics['val_loss'], label='Val Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.title('Loss over time')
        plt.subplot(1, 2, 2)
        plt.plot(metrics['precision'], label='Precision')
        plt.plot(metrics['recall'], label='Recall')
        plt.xlabel('Epoch')
        plt.ylabel('Score')
        plt.legend()
        plt.title('Metrics over time')
        plt.tight_layout()
        plt.savefig('training_metrics.png')
        plt.close()
        print("Training completed. Model saved as 'best_model.pth'")
    if check_display():
        print("Starting Interactive Colorization GUI...")
        app = InteractiveColorizationGUI()
        app.run()
    else:
        print("No display detected. Starting command-line interface...")
        cli_interface()

if __name__ == "__main__":
    main()


Starting Interactive Colorization GUI...


  checkpoint = torch.load(model_path, map_location=self.device)
