<a href="https://colab.research.google.com/github/Mushrooooxyoooom/Your-TMJ-Mouthguard/blob/main/Project_TDM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torchvision
!pip install torch
!pip install transformers torch

Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch==2.3.1->torchvision)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch==2.3.1->torchvision)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch==2.3.1->torchvision)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch==2.3.1->torchvision)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch==2.3.1->torchvision)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch==2.3.1->torchvision)
  Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl.me

In [24]:
import functools
import html

import requests
import torch
import torchvision.transforms as transforms
from IPython.display import clear_output, display, HTML
from PIL import Image
from torchvision.models import resnet50
from torchvision.models import ResNet50_Weights
from transformers import pipeline

# Load ResNet-50 with its ImageNet-1k weights. `weights=ResNet50_Weights.IMAGENET1K_V1` is the
# non-deprecated equivalent of `pretrained=True`.
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
# Inference mode: batch-norm layers use their running statistics instead of batch statistics.
model.eval()

# Where the ImageNet label list is downloaded from: a JSON array of label strings,
# indexed by the classifier's output index (see identify_ingredients).
IMAGENET_LABELS_URL = 'https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json'


@functools.lru_cache(maxsize=None)
def load_imagenet_labels(url=IMAGENET_LABELS_URL, timeout=30):
    """Download the ImageNet class labels, caching them after the first successful call.

    Args:
        url: Location of a JSON array of label strings.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The decoded JSON list of labels. The same list object is returned on every
        call with the same arguments, so callers must not mutate it.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if no response arrives within ``timeout`` seconds.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly instead of trying to JSON-decode an error page.
    response.raise_for_status()
    return response.json()

# Preprocessing pipeline: resize to 224x224, convert to a tensor, and normalise with the
# standard ImageNet per-channel mean/std. Built once and reused for every image.
_IMAGENET_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image_path):
    """Load an image file and turn it into a normalised batch of one image.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A float tensor of shape (1, 3, 224, 224).
    """
    # The context manager closes the underlying file handle; convert() returns a new,
    # fully loaded image, so it stays usable after the file is closed.
    with Image.open(image_path) as img:
        rgb_image = img.convert("RGB")
    return _IMAGENET_TRANSFORM(rgb_image).unsqueeze(0)

def identify_ingredients(image_path, top_k=2):
    """Classify an image with ResNet-50 and return its ``top_k`` most likely ImageNet labels.

    Args:
        image_path: Path of the image file to classify.
        top_k: How many predictions to return, most probable first.

    Returns:
        A list of ``(label, probability)`` tuples.
    """
    # NOTE(review): this reads the global `model`; later in this cell `model` is rebound to the
    # FLAN-T5 pipeline, so calling this function after that point would not use ResNet-50.
    batch = preprocess_image(image_path)
    with torch.no_grad():
        logits = model(batch)
    # Softmax over the class scores of the single image in the batch.
    class_probs = torch.nn.functional.softmax(logits[0], dim=0)
    best = torch.topk(class_probs, top_k)

    label_names = load_imagenet_labels()
    return [(label_names[i], p.item()) for p, i in zip(best.values, best.indices)]

# Example run: upload an image in Colab and identify what is in it.
from google.colab import files

# Open Colab's upload widget; the result is a dict keyed by uploaded file name.
uploaded = files.upload()

# Use the first uploaded file, and stop with a clear message if the user uploaded nothing.
if not uploaded:
    raise RuntimeError("No file was uploaded. Please run this cell again and choose an image.")
image_path = next(iter(uploaded))

# Identify the objects in the image and keep the top 2 predictions.
identified_ingredients = identify_ingredients(image_path, top_k=2)

# Default ingredients and their amounts in grams.
default_ingredients = {
    "Wheat flour": "17.5g",
    "Millet flour": "2.5g",
    "Tapioca starch": "3g",
    "Rice flour": "3g",
    "Glycerin": "1.5g"
}

# Amount of juice added to the recipe.
juice_amount = "12.5g"


def generate_final_ingredient_list(identified_ingredients, max_juice_options=2):
    """Build the full recipe: the default ingredients plus one juice entry.

    The juice entry offers the juices of up to ``max_juice_options`` identified
    ingredients as alternatives joined with " OR ", e.g.
    ``"apple's juice OR banana's juice": "12.5g"``.

    Args:
        identified_ingredients: Sequence of ``(label, probability)`` pairs, most likely first.
        max_juice_options: Maximum number of identified ingredients offered as juice alternatives.

    Returns:
        A new dict (``default_ingredients`` itself is not modified). If nothing was
        identified, it contains only the default ingredients.
    """
    ingredient_list = dict(default_ingredients)

    # A single identified ingredient still gets a juice entry ("X's juice").
    juice_options = [
        f"{entry[0]}'s juice"
        for entry in identified_ingredients[:max(max_juice_options, 0)]
    ]
    if juice_options:
        ingredient_list[" OR ".join(juice_options)] = juice_amount

    return ingredient_list

# Build the final ingredient list from the identified ingredients.
final_ingredient_list = generate_final_ingredient_list(identified_ingredients)

# Show the list as a centred heading followed by one centred line per ingredient, name in bold.
# Ingredient names come from downloaded labels, so names and amounts are HTML-escaped before
# being inserted into markup.
display(HTML("<h1 style='text-align:center;'>INGREDIENT LIST:</h1>"))
for ingredient, amount in final_ingredient_list.items():
    display(HTML(
        f"<p style='text-align:center;'><strong>{html.escape(ingredient)}:</strong> {html.escape(amount)}</p>"
    ))

def get_user_modification():
    """Prompt for one ingredient to change and its new amount.

    Accepts input such as 'Wheat flour: 20g', 'Wheat flour:20g' or 'Wheat flour：20g'
    (full-width colon); surrounding whitespace is ignored.

    Returns:
        ``(ingredient, new_amount)`` as stripped strings, e.g. ``('Wheat flour', '20g')``.

    Raises:
        ValueError: if the input has no colon, or the name or amount is empty.
    """
    print("\nPlease enter the ingredients to be modified and the new number of grams.（e.g. 'Wheat flour: 20g'）：")
    # Treat a full-width colon (common with CJK keyboards) like an ASCII one.
    modification_input = input().replace('：', ':')
    # The amount is expected after the final colon.
    ingredient, separator, new_amount = modification_input.rpartition(':')
    ingredient, new_amount = ingredient.strip(), new_amount.strip()
    if not separator or not ingredient or not new_amount:
        raise ValueError(f"Expected input like 'Wheat flour: 20g', got {modification_input!r}.")
    return ingredient, new_amount

def _parse_grams(amount):
    """Parse an amount such as '17.5g', '20 g', '20G' or '20' into a float number of grams.

    Raises:
        ValueError: if the text, without an optional trailing 'g'/'G', is not a number.
    """
    text = str(amount).strip()
    if text[-1:] in ('g', 'G'):
        text = text[:-1]
    return float(text)


def modify_ingredient(ingredient_list, modified_ingredient, new_amount):
    """Set one ingredient to a new amount and rescale every other ingredient by the same factor.

    Args:
        ingredient_list: Mapping of ingredient name -> amount string like '17.5g'. Not modified.
        modified_ingredient: Key of the ingredient whose amount is being changed.
        new_amount: New amount for that ingredient, e.g. '20g' (trailing 'g'/'G' optional).

    Returns:
        A new dict with the same key order, every amount formatted to one decimal place
        (e.g. '20.0g').

    Raises:
        KeyError: if ``modified_ingredient`` is not in ``ingredient_list``.
        ValueError: if an amount cannot be parsed, or the current amount of
            ``modified_ingredient`` is 0 (no scaling factor can be derived).
    """
    target_grams = _parse_grams(new_amount)
    reference_grams = _parse_grams(ingredient_list[modified_ingredient])
    if reference_grams == 0:
        raise ValueError(f"Cannot rescale from '{modified_ingredient}': its current amount is 0g.")
    ratio = target_grams / reference_grams

    # Every other ingredient keeps its proportion to the modified one.
    rescaled = {}
    for name, amount in ingredient_list.items():
        grams = target_grams if name == modified_ingredient else _parse_grams(amount) * ratio
        rescaled[name] = f"{grams:.1f}g"
    return rescaled

# Ask the user which ingredient to change and its new amount.
user_modified_ingredient, user_new_amount = get_user_modification()

# Clear the earlier output of this cell (first ingredient list and prompt) before showing the result.
clear_output()

# Look the typed name up case-insensitively, so 'wheat flour' matches 'Wheat flour'.
typed_name = user_modified_ingredient.strip().lower()
matched_ingredient = next(
    (name for name in final_ingredient_list if name.lower() == typed_name),
    None,
)

if matched_ingredient is None:
    print(f"Error: Ingredient '{user_modified_ingredient}' is not in the ingredient list.")
    updated_ingredient_list = final_ingredient_list
else:
    try:
        updated_ingredient_list = modify_ingredient(final_ingredient_list, matched_ingredient, user_new_amount)
    except (ValueError, ZeroDivisionError) as exc:
        # Unparseable amount (e.g. 'abc') or a 0g reference amount: keep the unmodified list.
        print(f"Error: could not update '{matched_ingredient}': {exc}")
        updated_ingredient_list = final_ingredient_list

# Show the updated list: a centred heading, then one centred line per ingredient with its name
# in bold. Names and amounts are HTML-escaped before being inserted into markup.
display(HTML("<h1 style='text-align:center;'>INGREDIENT LIST:</h1>"))
for ingredient, amount in updated_ingredient_list.items():
    display(HTML(
        f"<p style='text-align:center;'><strong>{html.escape(ingredient)}:</strong> {html.escape(amount)}</p>"
    ))

def initialize_model(model_name="google/flan-t5-small", device=-1):
    """Create a Hugging Face text2text-generation pipeline (FLAN-T5 small by default).

    Args:
        model_name: Hugging Face model id to load.
        device: -1 runs on the CPU; a non-negative integer selects that GPU (0 = first GPU).

    Returns:
        The pipeline; calling it with a prompt returns a list of dicts with a 'generated_text' key.
    """
    return pipeline("text2text-generation", model=model_name, device=device)


def get_juice_color(ingredient_name, model):
    """Ask the text model what colour the juice of ``ingredient_name`` is.

    Returns:
        The model's first generated answer with surrounding whitespace stripped.
    """
    prompt = f"What is the juice color of {ingredient_name}?"
    result = model(prompt)
    return result[0]['generated_text'].strip()


# Create the FLAN-T5 pipeline (on the CPU by default).
# NOTE(review): this rebinds the global `model`, which held the ResNet-50 classifier;
# identify_ingredients reads that global, so it must not be called after this line.
model = initialize_model()

# Ask FLAN-T5 for the juice colour of each identified ingredient, keeping (ingredient, colour) pairs in order.
colors = [(ingredient, get_juice_color(ingredient, model)) for ingredient, _ in identified_ingredients]

# Summarise both options in one sentence. The colours are model-generated text, so the
# sentence is HTML-escaped before being inserted into markup.
if len(colors) == 2:
    (ingredient1, color1), (ingredient2, color2) = colors
    sentence = f"You'll get a {color1}, {ingredient1}-flavored or a {color2}, {ingredient2}-flavored TMJ mouthguard."
    display(HTML(f"<h1 style='text-align:center;'>{html.escape(sentence)}</h1>"))
else:
    print("Error: The number of identified ingredients is not 2.")

# Known colour names (lower-case, no spaces) and their hex codes.
color_map = {
    'red': '#FF0000',
    'green': '#00FF00',
    'blue': '#0000FF',
    'yellow': '#FFFF00',
    'purple': '#800080',
    'cyan': '#00FFFF',
    'magenta': '#FF00FF',
    'white': '#FFFFFF',
    'black': '#000000',
    'orange': '#FFA500',
    'pink': '#FFC0CB',
    'brown': '#A52A2A',
    'gray': '#808080',
    'grey': '#808080',  # alternative spelling of 'gray'
    'lightblue': '#ADD8E6',
    'darkgreen': '#006400'
}


def get_hex_color(color_name, default='#FFFFFF'):
    """Map a free-text colour answer to a hex code from ``color_map``.

    Matching order:
      1. the whole answer, lower-cased (e.g. 'Red' -> 'red');
      2. two adjacent words joined (e.g. 'light blue' -> 'lightblue');
      3. any single word (e.g. 'The juice is orange.' -> 'orange').

    Args:
        color_name: Colour text, e.g. an answer produced by get_juice_color.
        default: Hex code returned when no known colour name is found (white by default).

    Returns:
        A hex colour string such as '#FF0000'.
    """
    text = color_name.lower()
    if text in color_map:
        return color_map[text]

    # Treat every non-letter (spaces, punctuation) as a word separator.
    words = ''.join(ch if ch.isalpha() else ' ' for ch in text).split()
    # Two-word names are tried first so 'dark green' is not reduced to plain 'green'.
    for first, second in zip(words, words[1:]):
        if first + second in color_map:
            return color_map[first + second]
    for word in words:
        if word in color_map:
            return color_map[word]
    return default

# Hex colour (e.g. '#FF0000') used to tint the 3D model; white when nothing was identified.
# Reuses the answers already collected in `colors` instead of querying FLAN-T5 again
# (assumes the pipeline's generation is deterministic, as with greedy decoding).
# NOTE(review): as in the original loop, the LAST identified ingredient's colour wins;
# confirm whether the top (first) prediction was intended.
hex_color = get_hex_color(colors[-1][1]) if colors else '#FFFFFF'

# three.js expects a numeric literal such as 0xFF0000: drop the leading '#' and prefix '0x'.
hex_without_hash = hex_color.lstrip('#')
hex_formatted = '0x' + hex_without_hash

# Build a standalone HTML page that renders the 3D mouthguard model with three.js r128
# (GLTFLoader + OrbitControls) and tints it with the juice colour. `hex_formatted` is
# substituted as a JavaScript hex literal such as 0xFFA500; literal braces are doubled
# because this is an f-string.
html_content = f"""
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>YOUR TMJ MOUTHGUARD</title>
    <style>
        body {{
            margin: 0;
            overflow: hidden;
            font-family: Arial, sans-serif;
        }}
        h1 {{
            text-align: center;
            font-size: 2.5em;
            font-weight: bold;
            margin-top: 20px;
        }}
        p {{
            text-align: center;
            font-size: 1.2em;
        }}
    </style>
</head>
<body>
    <h1>YOUR TMJ MOUTHGUARD</h1>
    <div id="scene-container"></div>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/three@0.128.0/examples/js/loaders/GLTFLoader.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/three@0.128.0/examples/js/controls/OrbitControls.js"></script>
    <script>
        const scene = new THREE.Scene();
        const camera = new THREE.PerspectiveCamera(75, window.innerWidth / window.innerHeight, 0.1, 1000);
        const renderer = new THREE.WebGLRenderer();
        renderer.setSize(window.innerWidth, window.innerHeight);
        renderer.setClearColor(0xffffff); // 设置背景颜色为白色
        document.getElementById('scene-container').appendChild(renderer.domElement);

        const loader = new THREE.GLTFLoader();
        let mixer;
        // Measures real time between frames so animations play at the same speed on any refresh rate
        const clock = new THREE.Clock();
        loader.load(
            'https://raw.githubusercontent.com/Mushrooooxyoooom/Your-TMJ-Mouthguard/main/food_small.gltf', // URL to the 3D model
            function (gltf) {{
                const model = gltf.scene;
                scene.add(model);
                model.position.set(0, -1, 0); // Position adjustment
                model.scale.set(0.5, 0.5, 0.5); // Scale down to 50%

                // Set up animations
                mixer = new THREE.AnimationMixer(model);
                const clips = gltf.animations;
                clips.forEach((clip) => {{
                    mixer.clipAction(clip).play();
                }});

                // Tint the model with the juice colour (hue kept; saturation/lightness fixed in changeModelColor)
                changeModelColor(
                    {hex_formatted}
                ); // Change Original color
            }},
            undefined,
            function (error) {{
                console.error(error);
            }}
        );

        const light = new THREE.AmbientLight(0x404040); // Soft white light
        scene.add(light);

        const directionalLight1 = new THREE.DirectionalLight(0xffffff, 2); // Intensity 2
        directionalLight1.position.set(1, 1, 1).normalize();
        scene.add(directionalLight1);
        directionalLight1.castShadow = true;
        directionalLight1.shadow.mapSize.width = 1024; // Shadow map width
        directionalLight1.shadow.mapSize.height = 1024; // Shadow map height
        directionalLight1.shadow.camera.near = 0.5; // Near clipping plane
        directionalLight1.shadow.camera.far = 50; // Far clipping plane
        directionalLight1.shadow.camera.left = -10; // Left clipping plane
        directionalLight1.shadow.camera.right = 10; // Right clipping plane
        directionalLight1.shadow.camera.top = 10; // Top clipping plane
        directionalLight1.shadow.camera.bottom = -10; // Bottom clipping plane

        const directionalLight2 = new THREE.DirectionalLight(0xffffff, 2); // Intensity 2
        directionalLight2.position.set(-1, -1, -1).normalize();
        scene.add(directionalLight2);
        directionalLight2.castShadow = true;

        camera.position.z = 5;

        const controls = new THREE.OrbitControls(camera, renderer.domElement);
        controls.enableDamping = true; // Damping is required for smooth animations
        controls.dampingFactor = 0.25;
        controls.enableZoom = true;
        controls.enablePan = true;

        function animate() {{
            requestAnimationFrame(animate);

            // Read the delta every frame (even before the model loads) so it never accumulates
            const delta = clock.getDelta();
            if (mixer) {{
                mixer.update(delta); // Advance animations by the seconds elapsed since the last frame
            }}

            controls.update(); // Update controls
            renderer.render(scene, camera);
        }}
        animate();

        function changeModelColor(colorHex) {{
            const color = new THREE.Color(colorHex);

            // Keep the hue of the requested colour but use a fixed saturation and lightness
            const hsl = {{}};
            color.getHSL(hsl);
            hsl.s = 0.8; // Saturation (0 to 1)
            hsl.l = 0.7; // Lightness (0 to 1)

            const adjustedColor = new THREE.Color().setHSL(hsl.h, hsl.s, hsl.l);

            scene.traverse((object) => {{
                if (object.isMesh) {{
                    // A mesh's material may be a single material or an array of materials
                    const materials = Array.isArray(object.material) ? object.material : [object.material];
                    materials.forEach((material) => {{
                        if (material && material.color) {{
                            material.color.copy(adjustedColor);
                        }}
                    }});
                }}
            }});
        }}

        window.addEventListener('resize', function () {{
            camera.aspect = window.innerWidth / window.innerHeight;
            camera.updateProjectionMatrix();
            renderer.setSize(window.innerWidth, window.innerHeight);
        }}, false);
    </script>
</body>
</html>
"""

# Render the page in the notebook output.
display(HTML(html_content))
