!xvfb-run -a -s "-screen 0 1024x768x24"

In [11]:
import os
if os.environ.get('DISPLAY', '') == '':
    os.environ.__setitem__('DISPLAY', ':0.0')
import tkinter as tk
from tkinter import ttk, filedialog, messagebox
import sqlite3
import bcrypt
import datetime
import numpy as np
from PIL import Image, ImageTk
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from facenet_pytorch import InceptionResnetV1

# Configuration
DATABASE_PATH = "criminal_sketch_system.db"
SKETCH_STORAGE = "uploaded_sketches/"
GENERATED_STORAGE = "generated_images/"
MODEL_PATH = "sketch2face_model.pth"

# Ensure directories exist
os.makedirs(SKETCH_STORAGE, exist_ok=True)
os.makedirs(GENERATED_STORAGE, exist_ok=True)

# Initialize database
def init_database():
    conn = sqlite3.connect(DATABASE_PATH)
    cursor = conn.cursor()

    # Create users table
    cursor.execute('''
    CREATE TABLE IF NOT EXISTS users (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        username TEXT UNIQUE NOT NULL,
        password TEXT NOT NULL,
        role TEXT NOT NULL,
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
    )''')

    # Create logs table
    cursor.execute('''
    CREATE TABLE IF NOT EXISTS logs (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        user_id INTEGER,
        sketch_path TEXT,
        generated_path TEXT,
        timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
        FOREIGN KEY (user_id) REFERENCES users (id)
    )''')

    # Create criminal database table
    cursor.execute('''
    CREATE TABLE IF NOT EXISTS criminal_database (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        image_path TEXT,
        name TEXT,
        criminal_id TEXT UNIQUE,
        details TEXT,
        added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
    )''')

    conn.commit()
    conn.close()

# Sketch to Face Generator Model
class Sketch2FaceModel(nn.Module):
    def __init__(self):
        super(Sketch2FaceModel, self).__init__()
        # Encoder (U-Net style)
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2)
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

class FaceComparisonModel:
    def __init__(self):
        self.model = InceptionResnetV1(pretrained='vggface2').eval()
        self.transform = transforms.Compose([
            transforms.Resize((160, 160)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

    def get_embedding(self, img):
        img = self.transform(img).unsqueeze(0)
        with torch.no_grad():
            embedding = self.model(img)
        return embedding

    def compare_faces(self, img1, img2):
        embedding1 = self.get_embedding(img1)
        embedding2 = self.get_embedding(img2)

        # Calculate cosine similarity
        similarity = torch.nn.functional.cosine_similarity(embedding1, embedding2)
        return similarity.item()

class LoginFrame(ttk.Frame):
    def __init__(self, parent, controller):
        super().__init__(parent)
        self.controller = controller

        # Configure style
        style = ttk.Style()
        style.configure("TFrame", background="#f0f0f0")
        style.configure("TButton", background="#4CAF50", foreground="black", font=("Arial", 12))
        style.configure("TLabel", background="#f0f0f0", font=("Arial", 12))
        style.configure("TEntry", font=("Arial", 12))

        # Set up the frame
        self.configure(style="TFrame", padding=(20, 20))

        # Title
        title_label = ttk.Label(self, text="Criminal Sketch Analysis System",
                                font=("Arial", 20, "bold"), style="TLabel")
        title_label.pack(pady=20)

        # Username
        username_frame = ttk.Frame(self, style="TFrame")
        username_frame.pack(fill="x", pady=5)

        username_label = ttk.Label(username_frame, text="Username:", style="TLabel")
        username_label.pack(side="left", padx=5)

        self.username_var = tk.StringVar()
        username_entry = ttk.Entry(username_frame, textvariable=self.username_var, width=25)
        username_entry.pack(side="right", padx=5)

        # Password
        password_frame = ttk.Frame(self, style="TFrame")
        password_frame.pack(fill="x", pady=5)

        password_label = ttk.Label(password_frame, text="Password:", style="TLabel")
        password_label.pack(side="left", padx=5)

        self.password_var = tk.StringVar()
        password_entry = ttk.Entry(password_frame, textvariable=self.password_var,
                                 show="*", width=25)
        password_entry.pack(side="right", padx=5)

        # Buttons
        button_frame = ttk.Frame(self, style="TFrame")
        button_frame.pack(pady=20)

        login_button = ttk.Button(button_frame, text="Login", command=self.login)
        login_button.pack(side="left", padx=10)

        register_button = ttk.Button(button_frame, text="Register", command=self.show_register)
        register_button.pack(side="right", padx=10)

    def login(self):
        username = self.username_var.get()
        password = self.password_var.get()

        if not username or not password:
            messagebox.showerror("Error", "Please enter both username and password")
            return

        # Verify credentials
        conn = sqlite3.connect(DATABASE_PATH)
        cursor = conn.cursor()

        cursor.execute("SELECT id, password, role FROM users WHERE username = ?", (username,))
        user = cursor.fetchone()

        conn.close()

        if not user:
            messagebox.showerror("Error", "Invalid username")
            return

        user_id, hashed_pw, role = user

        # Check password
        if bcrypt.checkpw(password.encode('utf-8'), hashed_pw):
            self.controller.user_id = user_id
            self.controller.username = username
            self.controller.role = role
            self.controller.show_frame("MainFrame")
        else:
            messagebox.showerror("Error", "Invalid password")

    def show_register(self):
        self.controller.show_frame("RegisterFrame")

class RegisterFrame(ttk.Frame):
    def __init__(self, parent, controller):
        super().__init__(parent)
        self.controller = controller

        # Configure style
        style = ttk.Style()
        self.configure(style="TFrame", padding=(20, 20))

        # Title
        title_label = ttk.Label(self, text="Register New User",
                                font=("Arial", 18, "bold"), style="TLabel")
        title_label.pack(pady=15)

        # Username
        username_frame = ttk.Frame(self, style="TFrame")
        username_frame.pack(fill="x", pady=5)

        username_label = ttk.Label(username_frame, text="Username:", style="TLabel")
        username_label.pack(side="left", padx=5)

        self.username_var = tk.StringVar()
        username_entry = ttk.Entry(username_frame, textvariable=self.username_var, width=25)
        username_entry.pack(side="right", padx=5)

        # Password
        password_frame = ttk.Frame(self, style="TFrame")
        password_frame.pack(fill="x", pady=5)

        password_label = ttk.Label(password_frame, text="Password:", style="TLabel")
        password_label.pack(side="left", padx=5)

        self.password_var = tk.StringVar()
        password_entry = ttk.Entry(password_frame, textvariable=self.password_var,
                                 show="*", width=25)
        password_entry.pack(side="right", padx=5)

        # Confirm Password
        confirm_frame = ttk.Frame(self, style="TFrame")
        confirm_frame.pack(fill="x", pady=5)

        confirm_label = ttk.Label(confirm_frame, text="Confirm Password:", style="TLabel")
        confirm_label.pack(side="left", padx=5)

        self.confirm_var = tk.StringVar()
        confirm_entry = ttk.Entry(confirm_frame, textvariable=self.confirm_var,
                                show="*", width=25)
        confirm_entry.pack(side="right", padx=5)

        # Role selection
        role_frame = ttk.Frame(self, style="TFrame")
        role_frame.pack(fill="x", pady=5)

        role_label = ttk.Label(role_frame, text="Role:", style="TLabel")
        role_label.pack(side="left", padx=5)

        self.role_var = tk.StringVar(value="officer")
        roles = [("Officer", "officer"), ("Admin", "admin")]

        role_options = ttk.Frame(role_frame, style="TFrame")
        role_options.pack(side="right")

        for text, value in roles:
            ttk.Radiobutton(role_options, text=text, value=value,
                          variable=self.role_var).pack(side="left", padx=10)

        # Buttons
        button_frame = ttk.Frame(self, style="TFrame")
        button_frame.pack(pady=15)

        register_button = ttk.Button(button_frame, text="Register", command=self.register)
        register_button.pack(side="left", padx=10)

        back_button = ttk.Button(button_frame, text="Back to Login",
                               command=lambda: controller.show_frame("LoginFrame"))
        back_button.pack(side="right", padx=10)

    def register(self):
        username = self.username_var.get()
        password = self.password_var.get()
        confirm_password = self.confirm_var.get()
        role = self.role_var.get()

        # Validation
        if not username or not password:
            messagebox.showerror("Error", "Please fill out all fields")
            return

        if password != confirm_password:
            messagebox.showerror("Error", "Passwords do not match")
            return

        if len(password) < 6:
            messagebox.showerror("Error", "Password must be at least 6 characters long")
            return

        # Hash password
        hashed_password = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt())

        # Save to database
        conn = sqlite3.connect(DATABASE_PATH)
        cursor = conn.cursor()

        try:
            cursor.execute(
                "INSERT INTO users (username, password, role) VALUES (?, ?, ?)",
                (username, hashed_password, role)
            )
            conn.commit()
            messagebox.showinfo("Success", "Registration successful! You can now login.")
            self.controller.show_frame("LoginFrame")
        except sqlite3.IntegrityError:
            messagebox.showerror("Error", "Username already exists")
        finally:
            conn.close()

class MainFrame(ttk.Frame):
    def __init__(self, parent, controller):
        super().__init__(parent)
        self.controller = controller
        self.sketch_path = None
        self.generated_image = None

        # Load models
        self.load_models()

        # Configure style
        self.configure(style="TFrame", padding=(20, 20))

        # Create main layout
        self.create_header()
        self.create_content_area()
        self.create_footer()

    def load_models(self):
        # Initialize the sketch to face model
        self.s2f_model = Sketch2FaceModel()
        if os.path.exists(MODEL_PATH):
            self.s2f_model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
        self.s2f_model.eval()

        # Initialize face comparison model
        self.face_comparison = FaceComparisonModel()

        # Image transforms
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

    def create_header(self):
        header_frame = ttk.Frame(self, style="TFrame")
        header_frame.pack(fill="x", pady=(0, 20))

        # Title
        title_label = ttk.Label(header_frame, text="Criminal Sketch Analysis System",
                              font=("Arial", 16, "bold"), style="TLabel")
        title_label.pack(side="left")

        # User info and logout
        user_frame = ttk.Frame(header_frame, style="TFrame")
        user_frame.pack(side="right")

        self.user_label = ttk.Label(user_frame, text="", style="TLabel")
        self.user_label.pack(side="left", padx=10)

        logout_button = ttk.Button(user_frame, text="Logout",
                                 command=self.logout)
        logout_button.pack(side="right")

    def create_content_area(self):
        content_frame = ttk.Frame(self, style="TFrame")
        content_frame.pack(fill="both", expand=True)

        # Left panel - Sketch upload
        left_panel = ttk.LabelFrame(content_frame, text="Upload Sketch", style="TFrame", padding=10)
        left_panel.pack(side="left", fill="both", expand=True, padx=(0, 10))

        upload_button = ttk.Button(left_panel, text="Select Sketch",
                                 command=self.upload_sketch)
        upload_button.pack(pady=10)

        self.sketch_canvas = tk.Canvas(left_panel, width=300, height=300, bg="white",
                                     highlightbackground="gray", highlightthickness=1)
        self.sketch_canvas.pack(pady=10)

        generate_button = ttk.Button(left_panel, text="Generate Face Image",
                                   command=self.generate_face)
        generate_button.pack(pady=10)

        # Right panel - Generated image and matches
        right_panel = ttk.LabelFrame(content_frame, text="Results", style="TFrame", padding=10)
        right_panel.pack(side="right", fill="both", expand=True, padx=(10, 0))

        results_top = ttk.Frame(right_panel, style="TFrame")
        results_top.pack(fill="x")

        # Generated image section
        generated_frame = ttk.LabelFrame(results_top, text="Generated Face", style="TFrame", padding=5)
        generated_frame.pack(side="left", fill="both", expand=True, padx=(0, 5))

        self.generated_canvas = tk.Canvas(generated_frame, width=300, height=300, bg="white",
                                        highlightbackground="gray", highlightthickness=1)
        self.generated_canvas.pack(pady=5)

        compare_button = ttk.Button(generated_frame, text="Find Matches",
                                  command=self.find_matches)
        compare_button.pack(pady=5)

        # Matches section
        matches_frame = ttk.LabelFrame(right_panel, text="Potential Matches", style="TFrame", padding=5)
        matches_frame.pack(fill="both", expand=True, pady=10)

        # Create scrollable frame for matches
        matches_canvas = tk.Canvas(matches_frame, highlightthickness=0)
        scrollbar = ttk.Scrollbar(matches_frame, orient="vertical", command=matches_canvas.yview)

        self.matches_container = ttk.Frame(matches_canvas, style="TFrame")

        matches_canvas.configure(yscrollcommand=scrollbar.set)

        scrollbar.pack(side="right", fill="y")
        matches_canvas.pack(side="left", fill="both", expand=True)

        matches_canvas.create_window((0, 0), window=self.matches_container, anchor="nw")
        self.matches_container.bind("<Configure>",
                                  lambda e: matches_canvas.configure(
                                      scrollregion=matches_canvas.bbox("all")))

    def create_footer(self):
        footer_frame = ttk.Frame(self, style="TFrame")
        footer_frame.pack(fill="x", pady=(20, 0))

        status_label = ttk.Label(footer_frame, text="Ready", style="TLabel")
        status_label.pack(side="left")

        # Add timestamp
        time_label = ttk.Label(footer_frame, text=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                             style="TLabel")
        time_label.pack(side="right")

    def update_user_info(self):
        if hasattr(self.controller, 'username') and hasattr(self.controller, 'role'):
            self.user_label.config(
                text=f"Logged in as: {self.controller.username} ({self.controller.role})"
            )

    def logout(self):
        self.controller.user_id = None
        self.controller.username = None
        self.controller.role = None
        self.controller.show_frame("LoginFrame")

    def upload_sketch(self):
        file_path = filedialog.askopenfilename(
            title="Select Sketch Image",
            filetypes=[("Image files", "*.png;*.jpg;*.jpeg;*.bmp")]
        )

        if not file_path:
            return

        try:
            # Save a copy to the storage directory
            filename = os.path.basename(file_path)
            timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
            new_filename = f"{timestamp}_{filename}"
            sketch_save_path = os.path.join(SKETCH_STORAGE, new_filename)

            # Open and resize image
            img = Image.open(file_path)
            img = img.resize((300, 300), Image.LANCZOS)
            img.save(sketch_save_path)

            # Display in canvas
            self.sketch_tk = ImageTk.PhotoImage(img)
            self.sketch_canvas.create_image(150, 150, image=self.sketch_tk)

            self.sketch_path = sketch_save_path
            messagebox.showinfo("Success", "Sketch uploaded successfully")

        except Exception as e:
            messagebox.showerror("Error", f"Failed to load image: {str(e)}")

    def generate_face(self):
        if not self.sketch_path:
            messagebox.showerror("Error", "Please upload a sketch first")
            return

        try:
            # Load and preprocess sketch
            sketch = Image.open(self.sketch_path)
            sketch = sketch.resize((256, 256), Image.LANCZOS)
            sketch_tensor = self.transform(sketch).unsqueeze(0)

            # Generate face image
            with torch.no_grad():
                generated = self.s2f_model(sketch_tensor)

            # Convert to image
            generated = generated.squeeze(0).detach()
            generated = (generated * 0.5 + 0.5) * 255
            generated = generated.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
            generated_img = Image.fromarray(generated)

            # Save generated image
            timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
            filename = os.path.basename(self.sketch_path)
            generated_filename = f"generated_{timestamp}_{filename}"
            generated_path = os.path.join(GENERATED_STORAGE, generated_filename)
            generated_img.save(generated_path)

            # Resize for display
            display_img = generated_img.resize((300, 300), Image.LANCZOS)
            self.generated_tk = ImageTk.PhotoImage(display_img)
            self.generated_canvas.create_image(150, 150, image=self.generated_tk)

            self.generated_image = generated_img
            self.generated_path = generated_path

            # Log the generation
            self.log_generation()

            messagebox.showinfo("Success", "Face image generated successfully")

        except Exception as e:
            messagebox.showerror("Error", f"Failed to generate face: {str(e)}")

    def log_generation(self):
        # Save to log
        if hasattr(self.controller, 'user_id') and self.sketch_path and hasattr(self, 'generated_path'):
            conn = sqlite3.connect(DATABASE_PATH)
            cursor = conn.cursor()

            cursor.execute(
                "INSERT INTO logs (user_id, sketch_path, generated_path) VALUES (?, ?, ?)",
                (self.controller.user_id, self.sketch_path, self.generated_path)
            )

            conn.commit()
            conn.close()

    def find_matches(self):
        if not hasattr(self, 'generated_image') or self.generated_image is None:
            messagebox.showerror("Error", "Please generate a face image first")
            return

        # Clear previous matches
        for widget in self.matches_container.winfo_children():
            widget.destroy()

        try:
            # Get all criminal database images
            conn = sqlite3.connect(DATABASE_PATH)
            cursor = conn.cursor()

            cursor.execute("SELECT id, image_path, name, criminal_id, details FROM criminal_database")
            db_images = cursor.fetchall()

            conn.close()

            if not db_images:
                ttk.Label(self.matches_container, text="No criminal database entries found. Please add some first.",
                        style="TLabel").pack(pady=10)
                return

            # For demo, let's simulate some matches
            # In a real app, we'd use the comparison model

            # Example matches with simulated similarity scores
            matches = []
            for criminal_id, img_path, name, criminal_code, details in db_images:
                if os.path.exists(img_path):
                    db_image = Image.open(img_path)

                    # Calculate similarity
                    similarity = self.face_comparison.compare_faces(self.generated_image, db_image)
                    similarity_percentage = min(100, max(0, similarity * 100))

                    matches.append((criminal_id, img_path, name, similarity_percentage, criminal_code, details))

            # Sort by similarity (highest first)
            matches.sort(key=lambda x: x[3], reverse=True)

            # Display top matches
            for i, (criminal_id, img_path, name, similarity, criminal_code, details) in enumerate(matches[:5]):
                match_frame = ttk.Frame(self.matches_container, style="TFrame")
                match_frame.pack(fill="x", pady=5)

                # Load image
                img = Image.open(img_path)
                img = img.resize((100, 100), Image.LANCZOS)
                img_tk = ImageTk.PhotoImage(img)

                # Store reference to avoid garbage collection
                match_frame.img_tk = img_tk

                # Image
                img_label = ttk.Label(match_frame, image=img_tk)
                img_label.pack(side="left", padx=10)

                # Details
                details_frame = ttk.Frame(match_frame, style="TFrame")
                details_frame.pack(side="left", fill="both", expand=True)

                ttk.Label(details_frame, text=f"Name: {name}", style="TLabel").pack(anchor="w")
                ttk.Label(details_frame, text=f"ID: {criminal_code}", style="TLabel").pack(anchor="w")
                ttk.Label(details_frame, text=f"Similarity: {similarity:.2f}%",
                        style="TLabel", foreground="red" if similarity > 80 else "black").pack(anchor="w")

                # View detail button
                ttk.Button(match_frame, text="View Details",
                         command=lambda cid=criminal_id: self.show_criminal_details(cid)).pack(side="right", padx=10)

            if not matches:
                ttk.Label(self.matches_container, text="No matches found in the database.",
                        style="TLabel").pack(pady=10)

        except Exception as e:
            messagebox.showerror("Error", f"Failed to find matches: {str(e)}")

    def show_criminal_details(self, criminal_id):
        # Get criminal details
        conn = sqlite3.connect(DATABASE_PATH)
        cursor = conn.cursor()

        cursor.execute(
            "SELECT image_path, name, criminal_id, details FROM criminal_database WHERE id = ?",
            (criminal_id,)
        )
        criminal = cursor.fetchone()

        conn.close()

        if not criminal:
            messagebox.showerror("Error", "Criminal record not found")
            return

        # Create details window
        details_window = tk.Toplevel(self)
        details_window.title("Criminal Details")
        details_window.geometry("500x400")

        img_path, name, criminal_code, details = criminal

        # Load image
        img = Image.open(img_path)
        img = img.resize((150, 150), Image.LANCZOS)
        img_tk = ImageTk.PhotoImage(img)

        # Store reference
        details_window.img_tk = img_tk

        ttk.Label(details_window, image=img_tk).pack(pady=10)
        ttk.Label(details_window, text=f"Name: {name}", font=("Arial", 14, "bold")).pack()
        ttk.Label(details_window, text=f"ID: {criminal_code}").pack()

        # Details text
        ttk.Label(details_window, text="Details:", font=("Arial", 12, "bold")).pack(anchor="w", padx=20, pady=(10, 0))

        details_text = tk.Text(details_window, height=8, width=50)
        details_text.pack(padx=20, pady=5, fill="both", expand=True)
        details_text.insert("1.0", details)
        details_text.config(state="disabled")

        # Close button
        ttk.Button(details_window, text="Close", command=details_window.destroy).pack(pady=10)

class CriminalSketchApp(tk.Tk):
    def __init__(self):
        super().__init__()

        self.title("Criminal Sketch Analysis System")
        self.geometry("1200x800")
        self.resizable(True, True)

        # Initialize database
        init_database()

        # Container for frames
        container = ttk.Frame(self)
        container.pack(side="top", fill="both", expand=True)

        # Initialize frames dictionary
        self.frames = {}

        # Create frames
        for F in (LoginFrame, RegisterFrame, MainFrame):
            frame = F(container, self)
            self.frames[F.__name__] = frame
            frame.grid(row=0, column=0, sticky="nsew")

        # Show login frame
        self.show_frame("LoginFrame")

    def show_frame(self, frame_name):
        frame = self.frames[frame_name]
        frame.tkraise()

        # Update user info if showing main frame
        if frame_name == "MainFrame":
            frame.update_user_info()

if __name__ == "__main__":
    app = CriminalSketchApp()
    app.mainloop()


TclError: couldn't connect to display ":0"