# Using Gemini to Identify Motifs
## Function Definition

In [None]:
import pandas as pd
from google import genai
from PIL import Image
import io
import numpy as np
import time
import matplotlib.pyplot as plt
from IPython.display import display
from holoviews.operation import threshold
from menuinst.platforms.win_utils.knownfolders import folder_path


#Read Image
def load_image(image_path, size=(1024, 1024)):
    """Load an image from disk and resize it for the Gemini prompt.

    TIFF files (``.tif`` / ``.tiff``, any letter case) are rescaled to 8-bit
    relative to their maximum pixel value before resizing.

    Args:
        image_path: Path (str or path-like) of the image file.
        size: Target ``(width, height)`` passed to ``Image.resize``;
            defaults to ``(1024, 1024)``.

    Returns:
        A ``PIL.Image.Image`` resized to ``size``.
    """
    with open(image_path, "rb") as img_file:
        data = img_file.read()
    image = Image.open(io.BytesIO(data))
    if str(image_path).lower().endswith((".tif", ".tiff")):
        # Convert to a NumPy array, rescale to uint8, and go back to a PIL image.
        image = Image.fromarray(_to_uint8(np.array(image)))
    return image.resize(size)


def _to_uint8(img_array):
    """Scale ``img_array`` into 0..255 ``uint8`` relative to its maximum value."""
    peak = img_array.max()
    if peak <= 0:
        # All-zero (or non-positive) input: dividing by the peak would yield
        # NaN/inf or flip signs, so return a black image of the same shape.
        return np.zeros(img_array.shape, dtype=np.uint8)
    scaled = 255 * (img_array / peak)
    # Clip so negative samples do not wrap around when cast to uint8.
    return np.clip(scaled, 0, 255).astype(np.uint8)

#Identify Motif
def identify_motif(image_path, name, description, max_retries=10, api_key=None):
    """Ask Gemini to locate and label every motif in a silk image.

    Args:
        image_path: Path of the image; it is loaded and resized with ``load_image``.
        name: Name of the silk piece, interpolated into the prompt.
        description: Free-text description of the piece, interpolated into the prompt.
        max_retries: Maximum number of attempts when the API reports a rate
            limit (an error message containing "429" or "rate limit").
        api_key: Gemini API key. When omitted, the ``GEMINI_API_KEY`` or
            ``GOOGLE_API_KEY`` environment variable is used, so no key has to
            be hard-coded in the notebook.

    Returns:
        The model's raw response text, or ``None`` if the request failed with
        a non-rate-limit error or every retry was rate-limited.
    """
    import os

    image = load_image(image_path)
    question=f"""
This is a silk image named {name}. There are also some description about this silk: {description}. Please identify all the motifs in this Chinese silk image and mark their locations with bounding boxes using the coordinates of the top-left and bottom-right corners. Bounding boxes should aim to fully enclose the motif. Please return a list includes the bounding boxes and corresponding labels.
"""

    if api_key is None:
        api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
    client = genai.Client(api_key=api_key)

    wait_time = 3  # initial back-off in seconds; doubled (capped at 60) per rate-limit hit
    for attempt in range(max_retries):
        try:
            response = client.models.generate_content(
                model="gemini-2.0-flash",
                contents=[question, image],
            )
            return response.text
        except Exception as e:
            error_msg = str(e)
            if "429" in error_msg or "rate limit" in error_msg.lower():
                print(f"请求受限，等待 {wait_time} 秒后重试...（第 {attempt + 1} 次）")
                time.sleep(wait_time)
                wait_time = min(wait_time * 2, 60)
            else:
                # Non-retryable failure: report it and stop without claiming
                # that the retry budget was exhausted.
                print(f"请求失败: {e}")
                return None

    print("达到最大重试次数，未能完成请求。")
    return None

In [2]:
import re
import json
import matplotlib.patches as patches

# Obtain the bounding boxes on the image
def identify_motifs(image_path, results):
    """Extract the list of motif detections from a Gemini response.

    Args:
        image_path: Not used; kept so existing callers keep working.
        results: Raw response text from ``identify_motif``. May be ``None``
            (that function returns ``None`` when the request fails).

    Returns:
        The JSON list found inside a fenced code block (```` ```json ```` or a
        bare ```` ``` ````, tag matched case-insensitively), or ``None`` when no
        parseable list is present. Callers presumably expect a list of dicts
        carrying a ``"label"`` key.
    """
    if not results:
        print("No JSON found.")
        return None
    match = re.search(r'```(?:json)?\s*(\[.*?\])\s*```', results, re.DOTALL | re.IGNORECASE)
    if match is None:
        print("No JSON found.")
        return None
    try:
        return json.loads(match.group(1))  # Python list (of dicts)
    except json.JSONDecodeError as err:
        # A malformed model reply should skip this image, not abort the batch.
        print(f"Invalid JSON: {err}")
        return None



In [4]:
#Motif classify into 6 categories
def classify_motif(labels, max_retries=5, api_key=None):
    """Ask Gemini to assign each motif label one of six category numbers.

    Categories: 0 animals, 1 plants, 2 inanimate objects, 3 geometric
    patterns, 4 textual motifs, 5 human motifs.

    Args:
        labels: Motif names, e.g. a list or pandas Series of strings. Any
            non-string iterable is converted to a plain list first, because a
            Series repr adds index/dtype lines and is truncated with "..."
            beyond pandas' display row limit, which would silently drop labels
            from the prompt.
        max_retries: Maximum number of attempts when the API reports a rate
            limit (an error message containing "429" or "rate limit").
        api_key: Gemini API key. When omitted, the ``GEMINI_API_KEY`` or
            ``GOOGLE_API_KEY`` environment variable is used.

    Returns:
        The model's raw response text, which the prompt asks to contain a
        list of category numbers. Returns ``None`` on a non-rate-limit error
        or when every retry was rate-limited.
    """
    import os

    if not isinstance(labels, str):
        labels = list(labels)
    question=f"""
Here are  a list of the motif’s names {labels}. Please classify them into 6 categories:
[0: "animals",
1: "plants",
2: "inanimate objects",
3: "geometric patterns",
4: "textual motifs",
5: "human motifs",
]
Please return a list only contains the number.
"""

    if api_key is None:
        api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
    client = genai.Client(api_key=api_key)

    wait_time = 3  # initial back-off in seconds; doubled (capped at 60) per rate-limit hit
    for attempt in range(max_retries):
        try:
            response = client.models.generate_content(
                model="gemini-2.0-flash",
                contents=question,
            )
            return response.text
        except Exception as e:
            error_msg = str(e)
            if "429" in error_msg or "rate limit" in error_msg.lower():
                print(f"请求受限，等待 {wait_time} 秒后重试...（第 {attempt + 1} 次）")
                time.sleep(wait_time)
                wait_time = min(wait_time * 2, 60)
            else:
                # Non-retryable failure: report and stop without claiming the
                # retry budget was exhausted.
                print(f"请求失败: {e}")
                return None

    print("达到最大重试次数，未能完成请求。")
    return None

    


## Process the data from the book

In [9]:
import os

# Directory containing all book images.
folder_path = "Data/Book_Dataset/Book_Dataset/Yuan"
# Results accumulated during this run.
df = pd.DataFrame(columns=["ID", "boxes", "labels"])
# CSV with the name/description metadata for each image ID.
metadata = pd.read_csv("Data/SilkPatternCollection.csv", encoding="utf-8-sig")
output_path = "output_yuan.csv"

# If the output CSV already exists, load the IDs processed by earlier runs.
# IDs are read as strings so they compare equal to the file-name stems below.
if os.path.exists(output_path):
    processed_df = pd.read_csv(output_path, encoding="utf-8-sig", dtype={"ID": str})
    processed_ids = set(processed_df["ID"].values)
else:
    processed_ids = set()

for image_path in sorted(os.listdir(folder_path)):
    # Hidden files such as .DS_Store are not images and would crash load_image.
    if image_path.startswith("."):
        continue
    image_name = os.path.splitext(image_path)[0]
    if image_name in processed_ids:
        print(f"跳过已处理：{image_name}")
        continue
    print(image_name)
    imagepath = os.path.join(folder_path, image_path)
    # Look up the metadata by ID and join it into plain text, so the prompt
    # does not contain a NumPy array repr such as "['...']".
    meta_data = metadata[metadata["ID"] == image_name]
    name = " ".join(meta_data["Name_zh"].dropna().astype(str))
    description = " ".join(meta_data["Description"].dropna().astype(str))
    results = identify_motif(imagepath, name, description)
    # Extract the bounding boxes.
    boxes = identify_motifs(imagepath, results)

    res2 = None
    if boxes is not None:
        # One label per detection (None placeholders keep the categories aligned
        # with the boxes). Unlike pd.DataFrame(boxes)["label"], this does not
        # raise KeyError for an empty list or entries without a "label" key.
        labels = [box.get("label") if isinstance(box, dict) else None for box in boxes]
        # Motif classification.
        if labels:
            res2 = classify_motif(labels)

    newrow = pd.DataFrame([{"ID": image_name, "boxes": boxes, "labels": res2}])
    df = pd.concat([df, newrow], ignore_index=True)
    # Write the header only when the file being appended to does not exist
    # yet; otherwise a header row would be repeated before every record.
    newrow.to_csv(output_path, mode="a", index=False,
                  header=not os.path.exists(output_path), encoding="utf-8-sig")

            
    

跳过已处理：Yuan001
跳过已处理：Yuan002
跳过已处理：Yuan002_detailed
跳过已处理：Yuan003
跳过已处理：Yuan004
跳过已处理：Yuan005
跳过已处理：Yuan006
跳过已处理：Yuan007
跳过已处理：Yuan008
跳过已处理：Yuan009_restored
跳过已处理：Yuan010
跳过已处理：Yuan011
跳过已处理：Yuan012
跳过已处理：Yuan013
跳过已处理：Yuan014
跳过已处理：Yuan015
跳过已处理：Yuan016
跳过已处理：Yuan017
跳过已处理：Yuan018
跳过已处理：Yuan019
跳过已处理：Yuan020
跳过已处理：Yuan021
跳过已处理：Yuan022
跳过已处理：Yuan023
跳过已处理：Yuan024
跳过已处理：Yuan025
跳过已处理：Yuan026
跳过已处理：Yuan027
跳过已处理：Yuan028
跳过已处理：Yuan029
跳过已处理：Yuan030
跳过已处理：Yuan031
跳过已处理：Yuan032
跳过已处理：Yuan033
跳过已处理：Yuan034
跳过已处理：Yuan035_restored
跳过已处理：Yuan036_restored
跳过已处理：Yuan037
跳过已处理：Yuan037_detailed
跳过已处理：Yuan038
跳过已处理：Yuan039
跳过已处理：Yuan040
跳过已处理：Yuan041
跳过已处理：Yuan042
跳过已处理：Yuan043_restored
跳过已处理：Yuan044
跳过已处理：Yuan045_restored
跳过已处理：Yuan046
跳过已处理：Yuan046_restored
跳过已处理：Yuan047
跳过已处理：Yuan048
跳过已处理：Yuan049_restored
跳过已处理：Yuan050
跳过已处理：Yuan051
跳过已处理：Yuan052
跳过已处理：Yuan053
跳过已处理：Yuan054
跳过已处理：Yuan055_restored
跳过已处理：Yuan056
跳过已处理：Yuan057_restored
跳过已处理：Yuan058_restored
跳过已处理：Yuan059_restored
跳过已处理：Yuan060_r

## Process the data from museum

In [6]:
import os

# Directory containing all museum images.
folder_path = "Data/Museum/Yuan/Yuan"
# Results accumulated during this run.
df = pd.DataFrame(columns=["ID", "boxes", "labels"])
# CSV holding the name/description metadata for each image.
metadata = pd.read_csv("metadata_silk_museum.csv", encoding="utf-8-sig")
output_path = "museum_yuan.csv"

# If the output CSV already exists, load the IDs processed by earlier runs.
# IDs are read as strings so they compare equal to the file-name stems below.
if os.path.exists(output_path):
    processed_df = pd.read_csv(output_path, encoding="utf-8-sig", dtype={"ID": str})
    processed_ids = set(processed_df["ID"].values)
else:
    processed_ids = set()

for image_path in sorted(os.listdir(folder_path)):
    # Hidden files such as .DS_Store are not images and would crash load_image.
    if image_path.startswith("."):
        continue
    image_name = os.path.splitext(image_path)[0]
    if image_name in processed_ids:
        print(f"跳过已处理：{image_name}")
        continue
    print(image_name)
    imagepath = os.path.join(folder_path, image_path)
    # Look up the description by image name and join it into plain text, so
    # the prompt does not contain a NumPy array repr such as "['...']".
    meta_data = metadata[metadata["name"] == image_name]
    name = " ".join(meta_data["name"].dropna().astype(str))
    description = " ".join(meta_data["description"].dropna().astype(str))
    results = identify_motif(imagepath, name, description)
    # Extract the bounding boxes.
    boxes = identify_motifs(imagepath, results)

    res2 = None
    if boxes is not None:
        # One label per detection (None placeholders keep the categories aligned
        # with the boxes). Unlike pd.DataFrame(boxes)["label"], this does not
        # raise KeyError for an empty list or entries without a "label" key.
        labels = [box.get("label") if isinstance(box, dict) else None for box in boxes]
        # Motif classification.
        if labels:
            res2 = classify_motif(labels)

    newrow = pd.DataFrame([{"ID": image_name, "boxes": boxes, "labels": res2}])
    df = pd.concat([df, newrow], ignore_index=True)
    # Write the header only when the file being appended to does not exist
    # yet; otherwise a header row would be repeated before every record.
    newrow.to_csv(output_path, mode="a", index=False,
                  header=not os.path.exists(output_path), encoding="utf-8-sig")

            
    

跳过已处理：万字菱格杂宝小花纹绮 
跳过已处理：云凤纹绫 
元云头龙凤杂宝纹暗花绫半臂袍
元代刺绣枕顶
元代墨绿回纹地雁衔花枝纹绫 
No JSON found.
元团窠对鸟纹彩锦 
元姑姑冠花绮抹额 
元牡丹花绫残片 
No JSON found.
元笠帽 
No JSON found.
元纳石失织金锦桃形窠姑姑冠冠披 
元纳石失织金锦袍衣襟 
元绢帽 
No JSON found.
元缂丝缠枝牡丹 
元褐色地卷草纹锦 
No JSON found.
元铠甲琐纹织金锦 
元黄色素绢 
No JSON found.
元龟背地团花锦马甲 
凤穿牡丹纹织金锦 
印金罗短袖衫 
印金罗长袖衫 
No JSON found.
印金罗长袖衫印金罗短袖衫_1
No JSON found.
印金罗长袖衫印金罗短袖衫_2
双头鸟纳石失 
No JSON found.
双胜纹妆花纱襟缘
对鸟纹彩锦
暗花绫腰线袍
No JSON found.
滴珠兔纹织金锦 
球路纹朵花绢 
瓣窠纳石失
红绢片
红色泰罗
纳石失靴套六出地格力芬绫裤 
织金奔鹿 
织金绫大袖袍 
No JSON found.
绉纱结 
No JSON found.
绢片
绮地印金石榴纹襟缘 
No JSON found.
绿色花卉纹绫印金方搭纹半臂 
缂丝天鹿云肩残片 
缂丝玉兔云肩残片
缠枝花卉缎残片
No JSON found.
罗地印金杂花襟缘 
No JSON found.
罗带 
No JSON found.
肩襕纹样纳石失
花卉纹罗袍
菱地飞鸟纹绫海青衣
菱格龙纹绮 
No JSON found.
蒙元六出地格力芬花绫裤
蒙元暗花绫腰线袍_1
No JSON found.
蒙元暗花绫腰线袍_2
No JSON found.
蒙元纳石失靴套 
褐色织金绢
No JSON found.
辫线袍
锦片
锦缘绢袍
No JSON found.
黄地花卉纹绫印金卧兽纹对襟上衣
黄绢残片
龟背地团花锦 


## Clean the output

In [5]:
import ast
import re
def extract_list(s):
    """Extract the first bracketed list from a model response string.

    Args:
        s: Raw response text, e.g. a fenced block containing ``[0, 1, 3]``.
            Non-string values (such as NaN from an empty CSV cell) yield ``[]``.

    Returns:
        The list parsed with ``ast.literal_eval`` when the bracketed text is a
        valid Python literal. Otherwise, the comma-separated items as stripped
        strings. Returns ``[]`` when no ``[...]`` span is found.
    """
    if not isinstance(s, str):
        return []
    # DOTALL so a list the model spreads over several lines is still found.
    match = re.search(r'\[(.*?)\]', s, re.DOTALL)
    if match is None:
        return []
    try:
        return ast.literal_eval(match.group(0))
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # Not a valid literal (e.g. unquoted words): fall back to a plain split.
        return [item.strip() for item in match.group(1).split(',') if item.strip()]

In [14]:
# Load the raw classification output, show its columns, and turn every
# "labels" cell (model text) into a Python list via extract_list.
df = pd.read_csv("museum_song.csv", encoding="utf-8-sig")
print(df.columns)
df["labels"] = df["labels"].map(extract_list)

Index(['ID', 'boxes', 'labels'], dtype='object')


In [15]:
# Save the cleaned table (without the pandas index) as UTF-8 with BOM.
df.to_csv("museum_song1.csv", index=False, encoding="utf-8-sig")