In [15]:
def _parse_ingredients(cell):
    """Return the set of ingredient tokens stored in one ``cleaned_ingredients`` cell.

    ``pd.read_csv`` yields a plain string for this column, so calling ``set()``
    on it directly would produce a set of single characters. This accepts:

    * an already-parsed list/tuple/set,
    * a Python-literal string such as ``"['gin', 'lemonjuice']"``,
    * a delimiter-separated string such as ``"gin lemonjuice grenadine"``.

    Missing cells (NaN/None or any other non-string scalar) yield an empty set.
    """
    import ast
    import re

    if isinstance(cell, (list, tuple, set, frozenset)):
        return {str(item).strip() for item in cell}
    if not isinstance(cell, str):
        return set()
    text = cell.strip()
    if text[:1] in ("[", "(", "{"):
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError):
            parsed = None
        if isinstance(parsed, (list, tuple, set, frozenset)):
            return {str(item).strip() for item in parsed}
    # NOTE(review): the on-disk format of cleaned_ingredients is not visible here.
    # Whitespace/comma splitting matches the space-joined query fed to the
    # vectorizer below; confirm against the preprocessing step.
    tokens = re.split(r"[\s,]+", text.strip("[](){}"))
    return {tok.strip("'\"") for tok in tokens if tok.strip("'\"")}


def test_deeper_model(model_path, vectorizer_path, data_path, test_ingredients, top_n=5):
    """Load the trained DeeperMultiLabelClassifier and print top-N cocktail recommendations.

    Args:
        model_path: Path to a saved ``state_dict`` for ``DeeperMultiLabelClassifier``.
        vectorizer_path: Path to the pickled, already-fitted vectorizer.
        data_path: CSV file with at least ``name`` and ``cleaned_ingredients`` columns.
        test_ingredients: Ingredient tokens to query with, e.g. ``["gin", "lemonjuice"]``.
        top_n: Maximum number of recommendations; clamped to the number of classes.

    Returns:
        List of ``(name, score, description)`` tuples, best first. They are also
        printed, as before.
    """
    # SECURITY: pickle.load can execute arbitrary code -- only load files you produced.
    with open(vectorizer_path, "rb") as f:
        vectorizer = pickle.load(f)

    data = pd.read_csv(data_path)
    # NOTE(review): class indices follow first-appearance order of names in the CSV;
    # this must match the label mapping used at training time -- confirm.
    cocktail_labels = {name: idx for idx, name in enumerate(data['name'].unique())}
    idx_to_label = {idx: name for name, idx in cocktail_labels.items()}
    num_classes = len(cocktail_labels)

    input_size = len(vectorizer.get_feature_names_out())
    model = DeeperMultiLabelClassifier(input_size=input_size, num_classes=num_classes)
    # map_location lets a GPU-saved checkpoint load on CPU. weights_only=True avoids
    # unpickling arbitrary objects and silences torch's FutureWarning about it.
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    test_vector = vectorizer.transform([' '.join(test_ingredients)]).toarray()
    test_tensor = torch.tensor(test_vector, dtype=torch.float32)

    with torch.no_grad():
        outputs = model(test_tensor).flatten()

    # Normalize recommendation scores to sum to 1 (used for display only).
    # Guard against a zero total, which would otherwise produce inf/nan.
    total = outputs.sum()
    normalized_scores = outputs / total if total.item() != 0 else outputs.clone()

    # Rank on the raw outputs: for a positive total this matches ranking the
    # normalized scores, and a negative total can no longer invert the order.
    k = max(0, min(top_n, outputs.numel()))
    predicted_indices = torch.topk(outputs, k=k).indices.tolist()

    recommendations = []
    for idx in predicted_indices:
        name = idx_to_label[idx]
        score = normalized_scores[idx].item()
        recipe_cell = data.loc[data['name'] == name, 'cleaned_ingredients'].iloc[0]
        recipe_ingredients = _parse_ingredients(recipe_cell)
        # Keep the caller's ingredient order (deduplicated) so output is deterministic.
        matching_ingredients = [
            ing for ing in dict.fromkeys(test_ingredients) if ing in recipe_ingredients
        ]
        if matching_ingredients:
            description = (
                f"{name} contains {', '.join(matching_ingredients)} matching your input."
            )
        else:
            description = f"{name} contains none of your input ingredients."
        recommendations.append((name, score, description))

    print("Top Recommendations:")
    for name, score, description in recommendations:
        print(f"{name}: {score:.2f}")
        print(f"Description: {description}")

    return recommendations


# Run a sample recommendation query when this file/cell is executed directly.
if __name__ == "__main__":
    test_ingredients = ["gin", "lemonjuice", "grenadine"]
    # Artifact paths are relative to the current working directory.
    test_deeper_model(
        "model/deeper_multi_label_model.pt",  # model_path
        "model/vectorizer.pkl",               # vectorizer_path
        "data/final_cocktails.csv",           # data_path
        test_ingredients,
        top_n=5,
    )


Top Recommendations:
Lone Tree Cocktail: 0.38
Screwdriver: 0.32
Pink Gin: 0.30
Gin and Soda: 0.28
Rum Screwdriver: 0.27


  model.load_state_dict(torch.load(model_path))
