In [2]:
from itertools import product

In [5]:
import pandas as pd
import pycosat
import plotly.graph_objects as go

# Constants
K = 5  # Number of houses
NATIONALITY = ['NORWEGIAN', 'UKRAINIAN', 'ENGLISHMAN', 'SPANIARD', 'JAPANESE']
COLOR = ['RED', 'GREEN', 'IVORY', 'YELLOW', 'BLUE']
DRINK = ['COFFEE', 'TEA', 'MILK', 'ORANGE_JUICE', 'WATER']
PET = ['DOG', 'SNAILS', 'FOX', 'HORSE', 'ZEBRA']
CIGARETTE = ['OLD_GOLD', 'KOOLS', 'CHESTERFIELDS', 'LUCKY_STRIKE', 'PARLIAMENTS']

ATTRIBUTES = [NATIONALITY, COLOR, DRINK, PET, CIGARETTE]

def var(house, attribute_type, value):
    return (house * len(ATTRIBUTES) + attribute_type) * K + value + 1

def unique_attribute():
    formula = []
    for attribute_type, values in enumerate(ATTRIBUTES):
        for house in range(K):
            # At least one attribute value per house
            formula.append([var(house, attribute_type, v) for v in range(K)])
            # No more than one attribute value per house
            for v1, v2 in product(range(K), repeat=2):
                if v1 < v2:
                    formula.append([-var(house, attribute_type, v1), -var(house, attribute_type, v2)])
        # Each attribute value is assigned to exactly one house
        for v in range(K):
            # At least one house per attribute value
            formula.append([var(house, attribute_type, v) for house in range(K)])
            # No more than one house per attribute value
            for house1, house2 in product(range(K), repeat=2):
                if house1 < house2:
                    formula.append([-var(house1, attribute_type, v), -var(house2, attribute_type, v)])
    return formula

def bi_implication(p, q):
    """Returns a bi-implication (iff) in CNF."""
    return [[-p, q], [p, -q]]


def encode_clues():
    formula = []
    
    # The Norwegian lives in the first house
    formula.append([var(0, 0, 0)])
    
    # The milk is drunk in the middle house
    formula.append([var(2, 2, 2)])
    
    # The Englishman lives in the red house
    formula.append([var(0, 1, 0), var(1, 1, 0), var(2, 1, 0), var(3, 1, 0), var(4, 1, 0)])
    
    # The green house is immediately to the right of the ivory house
    for house in range(4):
        formula.append([-var(house, 1, 2), var(house + 1, 1, 1)])
    
    # The Norwegian lives next to the blue house
    formula.append([-var(0, 0, 0), var(1, 1, 4)])
    
    # The Spaniard owns the dog
    for house in range(5):
        formula.extend(bi_implication(var(house, 0, 3), var(house, 3, 0)))
        
    # Coffee is drunk in the green house
    for house in range(5):
        formula.extend(bi_implication(var(house, 2, 0), var(house, 1, 1)))
        
    # The Ukrainian drinks tea
    for house in range(5):
        formula.extend(bi_implication(var(house, 0, 1), var(house, 2, 1)))
        
    # The Old Gold smoker owns snails
    for house in range(5):
        formula.extend(bi_implication(var(house, 4, 0), var(house, 3, 1)))
        
    # Kools are smoked in the yellow house
    for house in range(5):
        formula.extend(bi_implication(var(house, 4, 1), var(house, 1, 3)))
    
    # The Lucky Strike smoker drinks orange juice
    for house in range(5):
        formula.extend(bi_implication(var(house, 4, 3), var(house, 2, 3)))
        
    # The Japanese smokes Parliaments
    for house in range(5):
        formula.extend(bi_implication(var(house, 0, 4), var(house, 4, 4)))
        
    # The Chesterfields smoker is next to the fox owner
    for house in range(4):
        formula.append([-var(house, 4, 2), var(house + 1, 3, 2)])
        formula.append([-var(house + 1, 4, 2), var(house, 3, 2)])
        
    # Kools smoker is next to the horse owner
    for house in range(4):
        formula.append([-var(house, 4, 1), var(house + 1, 3, 3)])
        formula.append([-var(house + 1, 4, 1), var(house, 3, 3)])
    
    return formula


def solve():
    formula = unique_attribute()
    formula.extend(encode_clues())
    solution = pycosat.solve(formula)
    if solution == "UNSAT":
        return None
    else:
        solution = [s for s in solution if s > 0]
        return solution

def decode_solution(solution):
    dtf = pd.DataFrame(columns=["House", "Nationality", "Color", "Drink", "Pet", "Cigarette"])
    for s in solution:
        house = (s - 1) // (K * len(ATTRIBUTES))
        attribute_type = ((s - 1) % (K * len(ATTRIBUTES))) // K
        value = (s - 1) % K
        dtf.at[house, dtf.columns[attribute_type + 1]] = ATTRIBUTES[attribute_type][value]
    dtf["House"] = dtf.index + 1
    return dtf

def plot_knowledge_graph(dtf):
    edge_x = []
    edge_y = []
    node_x = []
    node_y = []
    text = []

    for index, row in dtf.iterrows():
        node_x.append(index)
        node_y.append(2)
        text.append(f'House {row["House"]}')

        for col in dtf.columns[1:]:
            node_x.append(index)
            node_y.append(1)
            text.append(row[col])
            edge_x.extend([index, index, None])
            edge_y.extend([2, 1, None])

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines')

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        hoverinfo='text',
        marker=dict(showscale=True,
                    colorscale='YlGnBu',
                    reversescale=True,
                    color=[],
                    size=10,
                    colorbar=dict(
                        thickness=15,
                        title='Node Connections',
                        xanchor='left',
                        titleside='right'
                    ),
                    line_width=2))
    node_trace.text = text

    node_adjacencies = []
    for node, adjacencies in enumerate(dtf.values):
        node_adjacencies.append(len(adjacencies))
    node_trace.marker.color = node_adjacencies

    fig = go.Figure(data=[edge_trace, node_trace],
                    layout=go.Layout(
                        title='Knowledge Graph of Einstein\'s Riddle',
                        titlefont_size=16,
                        showlegend=False,
                        hovermode='closest',
                        margin=dict(b=0, l=0, r=0, t=40),
                        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False))
                    )
    fig.show()

if __name__ == "__main__":
    solution = solve()
    if solution:
        dtf_solution = decode_solution(solution)
        print(dtf_solution)
        plot_knowledge_graph(dtf_solution)
    else:
        print("No solution found!")


No solution found!
