In [None]:
import base64
import io

import matplotlib.pyplot as plt
import numpy as np
from fastapi import FastAPI, HTTPException, Request
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from matplotlib.figure import Figure

# Create the FastAPI application instance; the route handlers register on it.
app = FastAPI()

# Serve the "static" directory under the /static URL prefix. The path is
# relative, so it is resolved against the process's working directory.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Jinja2 templates are loaded from the relative "templates" directory.
templates = Jinja2Templates(directory="templates")

# Image class labels keyed by integer class index (0-9, in insertion order).
label_dict = dict(enumerate([
    'Animals', 'Architecture', 'Clothing', 'Flower', 'Food',
    'Portrait', 'Scenery', 'Transportation', 'BillReceipt', 'FoodCuisine',
]))

# Route: landing page (GET /), rendered from index.html.
@app.get("/")
def home(request: Request):
    # The template context must carry the request itself; "labels" receives
    # the dict_values view of label_dict, iterated in insertion order.
    context = {"request": request, "labels": label_dict.values()}
    return templates.TemplateResponse("index.html", context)

# Route: show every training image that carries the requested label.
@app.get("/show_images")
def show_images(request: Request, label: str):
    """Render ``show_images.html`` with all training images labelled *label*.

    Each matching image is drawn as a grayscale 2x2-inch figure, encoded as a
    base64 PNG, and embedded as an ``<img>`` data URI. The concatenated tags
    are passed to the template as ``images``.

    Raises:
        HTTPException: 404 if *label* is not one of the ``label_dict`` values.
    """
    # Reverse lookup name -> class id (the label_dict key). With keys 0..9 in
    # insertion order, this equals the label's position among the values.
    name_to_index = {name: idx for idx, name in label_dict.items()}
    try:
        label_index = name_to_index[label]
    except KeyError:
        # Unknown labels are a client error, not an unhandled server error.
        raise HTTPException(
            status_code=404, detail=f"Unknown label: {label}"
        ) from None

    # NOTE(review): trainImages / trainLabels are not defined in this cell;
    # presumably NumPy arrays from an earlier notebook cell -- confirm.
    images = trainImages[trainLabels == label_index]

    tags = []
    for image in images:
        # Use a standalone Figure instead of pyplot. Figures created with
        # plt.subplots stay registered (and in memory) until plt.close(), and
        # pyplot's global state is not meant for the worker threads that
        # FastAPI runs plain ``def`` endpoints in.
        fig = Figure(figsize=(2, 2))
        ax = fig.add_subplot(1, 1, 1)
        ax.imshow(image, cmap='gray')
        ax.axis('off')

        image_str = get_image_str(fig)
        # `label` is safe to interpolate: it matched a fixed label_dict value.
        tags.append(
            f'<img src="data:image/png;base64,{image_str}" '
            f'alt="{label}" width="200" height="200">'
        )
    image_html = "".join(tags)

    return templates.TemplateResponse(
        "show_images.html",
        {"request": request, "label": label, "images": image_html},
    )

# Helper: encode a Matplotlib figure as a base64 image string.
def get_image_str(fig, *, fmt="png"):
    """Render *fig* into memory and return it as base64-encoded ASCII text.

    Args:
        fig: Object with a Matplotlib-style ``savefig(file, format=...)``
            method, e.g. a ``matplotlib.figure.Figure``.
        fmt: Image format passed to ``savefig``. Defaults to ``"png"``, which
            matches the ``data:image/png`` URIs built in ``show_images``.

    Returns:
        The base64 encoding of the rendered bytes as a ``str``, suitable for
        embedding in a ``data:`` URI.
    """
    # `with` closes the buffer even if savefig raises. getvalue() returns the
    # whole buffer, so no seek/read is needed.
    with io.BytesIO() as buf:
        fig.savefig(buf, format=fmt)
        data = buf.getvalue()
    return base64.b64encode(data).decode("ascii")

