In [1]:
pip install openai

^C
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/opt/conda/lib/python3.10/site-packages/pip/__main__.py", line 24, in <module>
    sys.exit(_main())
  File "/opt/conda/lib/python3.10/site-packages/pip/_internal/cli/main.py", line 77, in main
    command = create_command(cmd_name, isolated=("--isolated" in cmd_args))
  File "/opt/conda/lib/python3.10/site-packages/pip/_internal/commands/__init__.py", line 114, in create_command
    module = importlib.import_module(module_path)
  File "/opt/conda/lib/python3.10/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1050, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1027, in _find_and_load
  File "<frozen i

# 先让大模型对图片进行描述，再用提示词输出标签（两步prompt）

In [2]:
import torch
from torchvision import datasets, transforms
import openai
import base64
import requests

import io
from PIL import Image
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

# Set the API key. It is read from the OPENAI_API_KEY environment variable so the
# secret does not have to be pasted into (and saved with) the notebook; the
# original blank placeholder is kept as the fallback value.
import os

openai.api_key = os.environ.get('OPENAI_API_KEY', ' ')
# Fashion-MNIST class names; the list index is the integer label used by the dataset
fashion_labels = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
]


# Image transform: convert to a tensor, then normalize.
# NOTE: Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5, so pixel values end
# up in [-1, 1] instead of the [0, 1] range produced by ToTensor.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load the Fashion-MNIST test split (downloaded into ./data when missing)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)
# Data loader yielding one image per batch, in dataset order (shuffling disabled)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

# Cap on the number of images processed, to limit API cost; adjust as needed
num_samples = 20

# Ask the user for the two prompt texts
def get_prompt():
    """Interactively read both prompts from standard input.

    Returns:
        tuple: ``(first_prompt, second_prompt)`` exactly as typed; the first
        prompt is requested before the second.
    """
    questions = (
        "Please enter the first text prompt: ",
        "Please enter the second text prompt for accuracy computation: ",
    )
    # The generator is consumed left to right, so the questions are asked in order
    return tuple(input(question) for question in questions)

# Encode an image as a base64 JPEG string
def encode_image(image_path):
    """Serialize an image to JPEG in memory and return it base64-encoded.

    Args:
        image_path: despite the name, a PIL image object (anything exposing
            ``mode``, ``convert`` and ``save``), not a filesystem path.

    Returns:
        str: base64 text of the JPEG bytes, suitable for a
        ``data:image/jpeg;base64,...`` URL.
    """
    image = image_path
    # JPEG cannot store alpha or palette data; convert such modes
    # (RGBA, LA, P, ...) to RGB first, otherwise save() raises.
    if image.mode not in ("L", "RGB", "CMYK"):
        image = image.convert("RGB")

    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')

# Map a free-text model reply to a label index
def _match_label(text, labels):
    """Return the index of the first label (in list order) mentioned in ``text``, or -1.

    Matching is case-insensitive and uses the part of each label before any
    '/', so a reply of "T-shirt" matches 'T-shirt/top'. A label that is a
    substring of another label ('Shirt' inside 'T-shirt') only counts where
    it appears outside that longer name.
    """
    lowered = text.lower()
    names = [label.lower().split('/')[0] for label in labels]
    for index, name in enumerate(names):
        searchable = lowered
        for other in names:
            if other != name and name in other:
                # Blank out the longer name so it cannot produce a false hit
                searchable = searchable.replace(other, ' ')
        if name in searchable:
            return index
    return -1


# Classification function
def classify_image(image_tensor, prompt_text):
    """Send one image plus a text prompt to the chat-completions API and map the reply to a label.

    Args:
        image_tensor: batch of images as yielded by the data loader; only the
            first item (``image_tensor[0]``) is used.
        prompt_text: text sent alongside the image.

    Returns:
        int: index into ``fashion_labels`` of the matched label, or -1 when
        the request fails, an exception occurs, or no label is mentioned.
    """
    image = image_tensor[0]
    # ToPILImage multiplies float tensors by 255 and casts to bytes, so it
    # needs values in [0, 1]. Tensors that went through
    # Normalize((0.5,), (0.5,)) are in [-1, 1]; negative values identify that
    # case and are mapped back before conversion.
    if image.min() < 0:
        image = image * 0.5 + 0.5
    pil_image = transforms.ToPILImage()(image.clamp(0, 1))

    # Encode the image as base64 JPEG text
    base64_image = encode_image(pil_image)

    # Request headers and payload
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {openai.api_key}"
    }

    payload = {
        "model": "gpt-4o-2024-08-06",
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": prompt_text  # caller-supplied prompt
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{base64_image}"
                        }
                    }
                ]
            }
        ],
        "max_tokens": 300
    }

    try:
        # Call the API; the timeout keeps a stalled connection from hanging the run
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=payload,
            timeout=60,
        )

        if response.status_code == 200:
            # Extract the text of the model's reply
            model_response = response.json()
            model_content = model_response['choices'][0]['message']['content'].strip()
            print(f"Model response: {model_content}")

            return _match_label(model_content, fashion_labels)
        else:
            print(f"Request failed with status: {response.status_code}, {response.text}")

    except Exception as e:
        # Best effort: report the problem and fall through to the error marker
        print(f"Error processing image: {e}")

    return -1  # error / no-match marker

# Read the two prompts from the user
first_prompt, second_prompt = get_prompt()

# Predictions for the first prompt and the ground-truth labels they are scored against
y_pred = []
y_true = []

# First pass: classify the first num_samples test images with the first prompt
for idx, (image, label) in zip(range(num_samples), test_loader):
    predicted_label = classify_image(image, first_prompt)
    y_pred.append(predicted_label)
    y_true.append(label.item())

# Second pass over the same images with the second prompt.
# NOTE: each classify_image call sends a single-message request, so the model
# does not see its first-pass reply during this pass.
y_pred_second = []

for idx, (image, label) in zip(range(num_samples), test_loader):
    predicted_label = classify_image(image, second_prompt)
    y_pred_second.append(predicted_label)

# Accuracy for both prompts (the first-pass predictions were previously
# collected but never reported). A prediction of -1 counts as an error.
accuracy_first = accuracy_score(y_true, y_pred)
print(f"Accuracy with first prompt: {accuracy_first}")
accuracy = accuracy_score(y_true, y_pred_second)
print(f"Accuracy with second prompt: {accuracy}")

# Show every processed image with its true and predicted label.
# Bounded by the samples actually collected so a short run cannot raise IndexError.
for i in range(len(y_true)):
    # Move the channel axis last and drop it to get a 2-D array for imshow
    plt.imshow(test_dataset[i][0].permute(1, 2, 0).numpy().squeeze(), cmap='gray')
    true_label = fashion_labels[y_true[i]]
    predicted_label = fashion_labels[y_pred_second[i]] if y_pred_second[i] != -1 else "Unknown"
    plt.title(f'True: {true_label}, Pred: {predicted_label}')
    plt.show()


ModuleNotFoundError: No module named 'openai'

# 第一步给一张图片作为prompt，第二步给文本prompt要求分类，以第二步的输出进行准确度计算（两步不同输入的prompt）

In [None]:

import torch
from torchvision import datasets, transforms
import openai
import base64
import requests

import io
from PIL import Image
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

# Set the API key. It is read from the OPENAI_API_KEY environment variable so the
# secret does not have to be pasted into (and saved with) the notebook; the
# original blank placeholder is kept as the fallback value.
import os

openai.api_key = os.environ.get('OPENAI_API_KEY', ' ')

# Fashion-MNIST class names; the list index is the integer label used by the dataset
fashion_labels = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
]

# Image transform: convert to a tensor, then normalize.
# NOTE: Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5, so pixel values end
# up in [-1, 1] instead of the [0, 1] range produced by ToTensor.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load the Fashion-MNIST test split (downloaded into ./data when missing)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

# Data loader yielding one image per batch, in dataset order (shuffling disabled)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

# Cap on the number of images processed
num_samples = 20

# Function: encode an image as a base64 JPEG string
def encode_image(image):
    """Serialize a PIL image to JPEG in memory and return it base64-encoded.

    Args:
        image: PIL image (anything exposing ``mode``, ``convert`` and ``save``).

    Returns:
        str: base64 text of the JPEG bytes, suitable for a
        ``data:image/jpeg;base64,...`` URL.
    """
    # JPEG cannot store alpha or palette data. The original handled RGBA only;
    # palette (P) or LA images, common for PNG/GIF files, would make save() raise.
    if image.mode not in ("L", "RGB", "CMYK"):
        image = image.convert('RGB')

    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')

# Step 1: read a user-chosen image that serves as the conceptual primer
def get_image_prompt():
    """Ask for an image path, load the image and return it base64-encoded.

    Returns:
        str | None: the value returned by ``encode_image`` for the loaded
        image, or ``None`` when the file cannot be opened or encoded (the
        error is printed).
    """
    # Pasted paths often carry stray surrounding whitespace
    image_path = input("Enter the path of the image to use as a conceptual prompt: ").strip()
    try:
        # Image.open keeps the file open until the data is loaded; the context
        # manager guarantees the handle is released once encoding is done.
        with Image.open(image_path) as image:
            return encode_image(image)
    except Exception as e:
        print(f"Could not open image: {e}")
        return None

# Step 2: read the user's text prompt
def get_text_prompt():
    """Ask on standard input for the text prompt and return it unchanged."""
    typed_prompt = input("Please enter your text prompt: ")
    return typed_prompt

# Map a free-text model reply to a label index
def _match_label(text, labels):
    """Return the index of the first label (in list order) mentioned in ``text``, or -1.

    Matching is case-insensitive and uses the part of each label before any
    '/', so a reply of "T-shirt" matches 'T-shirt/top'. A label that is a
    substring of another label ('Shirt' inside 'T-shirt') only counts where
    it appears outside that longer name.
    """
    lowered = text.lower()
    names = [label.lower().split('/')[0] for label in labels]
    for index, name in enumerate(names):
        searchable = lowered
        for other in names:
            if other != name and name in other:
                # Blank out the longer name so it cannot produce a false hit
                searchable = searchable.replace(other, ' ')
        if name in searchable:
            return index
    return -1


# Function: classify an image through the API
def classify_image(image_tensor, prompt_text):
    """Send one image plus a text prompt to the chat-completions API and map the reply to a label.

    Args:
        image_tensor: batch of images as yielded by the data loader; only the
            first item (``image_tensor[0]``) is used.
        prompt_text: text sent alongside the image.

    Returns:
        int: index into ``fashion_labels`` of the matched label, or -1 when
        the request fails, an exception occurs, or no label is mentioned.
    """
    image = image_tensor[0]
    # ToPILImage multiplies float tensors by 255 and casts to bytes, so it
    # needs values in [0, 1]. Tensors that went through
    # Normalize((0.5,), (0.5,)) are in [-1, 1]; negative values identify that
    # case and are mapped back before conversion.
    if image.min() < 0:
        image = image * 0.5 + 0.5
    pil_image = transforms.ToPILImage()(image.clamp(0, 1))
    base64_image = encode_image(pil_image)

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {openai.api_key}"
    }

    # One user message: the prompt text followed by the image as a data URL
    messages_content = [
        {
            "type": "text",
            "text": prompt_text
        },
        {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/jpeg;base64,{base64_image}"
            }
        }
    ]

    payload = {
        "model": "gpt-4o-2024-08-06",
        "messages": [
            {
                "role": "user",
                "content": messages_content
            }
        ],
        "max_tokens": 300
    }

    try:
        # The timeout keeps a stalled connection from hanging the run
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=payload,
            timeout=60,
        )

        if response.status_code == 200:
            model_response = response.json()
            model_content = model_response['choices'][0]['message']['content'].strip()
            print(f"Model response: {model_content}")

            return _match_label(model_content, fashion_labels)
        else:
            print(f"Request failed with status: {response.status_code}, {response.text}")

    except Exception as e:
        # Best effort: report the problem and fall through to the error marker
        print(f"Error processing image: {e}")

    return -1

# Predictions and the ground-truth labels they are scored against
y_pred = []
y_true = []

# Read and encode the concept image
base64_image_prompt = get_image_prompt()
if not base64_image_prompt:
    print("Image prompt not obtained, exiting program.")
    # exit() is a helper meant for the interactive interpreter; raising
    # SystemExit is the programmatic way to stop here.
    raise SystemExit

# Send the concept image to the model and show its reply.
# NOTE: this is a standalone request. The classify_image calls below each send
# only their own single message, so this exchange is not part of their context.
headers = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {openai.api_key}"
}

concept_payload = {
    "model": "gpt-4o-2024-08-06",
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{base64_image_prompt}"
                    }
                },
                {
                    "type": "text",
                    "text": "Please use this image to understand the concept for further classification tasks."
                }
            ]
        }
    ],
    "max_tokens": 300
}

try:
    # The timeout keeps a stalled connection from hanging the notebook
    concept_response = requests.post(
        "https://api.openai.com/v1/chat/completions",
        headers=headers,
        json=concept_payload,
        timeout=60,
    )
    if concept_response.status_code == 200:
        concept_model_response = concept_response.json()
        concept_model_content = concept_model_response['choices'][0]['message']['content'].strip()
        print(f"Conceptual understanding response: {concept_model_content}")
    else:
        print(f"Concept setup failed with status: {concept_response.status_code}, {concept_response.text}")
except Exception as e:
    print(f"Error setting concept: {e}")

# Read the user's classification prompt
prompt_text = get_text_prompt()

# Classify the first num_samples test images
for idx, (image, label) in zip(range(num_samples), test_loader):
    predicted_label = classify_image(image, prompt_text)
    y_pred.append(predicted_label)
    y_true.append(label.item())

# Accuracy; a prediction of -1 (no label matched) counts as an error
accuracy = accuracy_score(y_true, y_pred)
print(f"Accuracy: {accuracy}")

# Show every processed image with its true and predicted label.
# Bounded by the samples actually collected so a short run cannot raise IndexError.
for i in range(len(y_true)):
    # Move the channel axis last and drop it to get a 2-D array for imshow
    plt.imshow(test_dataset[i][0].permute(1, 2, 0).numpy().squeeze(), cmap='gray')
    true_label = fashion_labels[y_true[i]]
    predicted_label = fashion_labels[y_pred[i]] if y_pred[i] != -1 else "Unknown"
    plt.title(f'True: {true_label}, Pred: {predicted_label}')
    plt.show()

# 对每种label提供描述（只有一步prompt）

In [None]:
import torch
from torchvision import datasets, transforms
import openai
import base64
import requests

import io
from PIL import Image
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

# Set the API key. It is read from the OPENAI_API_KEY environment variable so the
# secret does not have to be pasted into (and saved with) the notebook; the
# original placeholder value is kept as the fallback.
import os

openai.api_key = os.environ.get('OPENAI_API_KEY', '1')
# Fashion-MNIST class names; the list index is the integer label used by the dataset
fashion_labels = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
]

# Image transform: convert to a tensor, then normalize.
# NOTE: Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5, so pixel values end
# up in [-1, 1] instead of the [0, 1] range produced by ToTensor.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load the Fashion-MNIST test split (downloaded into ./data when missing)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)
# Data loader yielding one image per batch, in dataset order (shuffling disabled)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

# Cap on the number of images processed, to limit API cost; adjust as needed
num_samples = 20

# Provide the classification prompt text
def get_prompt():
    """Return the fixed prompt sent with every image.

    The text lists the ten candidate labels and asks the model to describe the
    shapes, then their features, and finally to produce a label. Edit the
    string below to change the prompt.
    """
    prompt = "can you help me classify this picture from 'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot', please first tell what shape is inside this picture, then tell what features are these shapes, then generate a label"
    return prompt

# Encode an image as a base64 JPEG string
def encode_image(image_path):
    """Serialize an image to JPEG in memory and return it base64-encoded.

    Args:
        image_path: despite the name, a PIL image object (anything exposing
            ``mode``, ``convert`` and ``save``), not a filesystem path.

    Returns:
        str: base64 text of the JPEG bytes, suitable for a
        ``data:image/jpeg;base64,...`` URL.
    """
    image = image_path
    # JPEG cannot store alpha or palette data; convert such modes
    # (RGBA, LA, P, ...) to RGB first, otherwise save() raises.
    if image.mode not in ("L", "RGB", "CMYK"):
        image = image.convert("RGB")

    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')

# Map a free-text model reply to a label index
def _match_label(text, labels):
    """Return the index of the first label (in list order) mentioned in ``text``, or -1.

    Matching is case-insensitive and uses the part of each label before any
    '/', so a reply of "T-shirt" matches 'T-shirt/top'. A label that is a
    substring of another label ('Shirt' inside 'T-shirt') only counts where
    it appears outside that longer name.
    """
    lowered = text.lower()
    names = [label.lower().split('/')[0] for label in labels]
    for index, name in enumerate(names):
        searchable = lowered
        for other in names:
            if other != name and name in other:
                # Blank out the longer name so it cannot produce a false hit
                searchable = searchable.replace(other, ' ')
        if name in searchable:
            return index
    return -1


# Classification function
def classify_image(image_tensor):
    """Send one image with the prompt from ``get_prompt()`` to the API and map the reply to a label.

    Args:
        image_tensor: batch of images as yielded by the data loader; only the
            first item (``image_tensor[0]``) is used.

    Returns:
        int: index into ``fashion_labels`` of the matched label, or -1 when
        the request fails, an exception occurs, or no label is mentioned.
    """
    image = image_tensor[0]
    # ToPILImage multiplies float tensors by 255 and casts to bytes, so it
    # needs values in [0, 1]. Tensors that went through
    # Normalize((0.5,), (0.5,)) are in [-1, 1]; negative values identify that
    # case and are mapped back before conversion.
    if image.min() < 0:
        image = image * 0.5 + 0.5
    pil_image = transforms.ToPILImage()(image.clamp(0, 1))

    # Encode the image as base64 JPEG text
    base64_image = encode_image(pil_image)

    # Request headers and payload
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {openai.api_key}"
    }

    payload = {
        "model": "gpt-4o-2024-08-06",
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": get_prompt()  # fixed prompt text
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{base64_image}"
                        }
                    }
                ]
            }
        ],
        "max_tokens": 300
    }

    try:
        # Call the API; the timeout keeps a stalled connection from hanging the run
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=payload,
            timeout=60,
        )

        if response.status_code == 200:
            # Extract the text of the model's reply
            model_response = response.json()
            model_content = model_response['choices'][0]['message']['content'].strip()
            print(f"Model response: {model_content}")

            return _match_label(model_content, fashion_labels)
        else:
            print(f"Request failed with status: {response.status_code}, {response.text}")

    except Exception as e:
        # Best effort: report the problem and fall through to the error marker
        print(f"Error processing image: {e}")

    return -1  # error / no-match marker

# Accumulators: predictions and the ground-truth labels they are scored against
y_pred, y_true = [], []

# Walk the test loader, classifying images until the sample cap is reached
for idx, (image, label) in enumerate(test_loader):
    if idx < num_samples:
        predicted_label = classify_image(image)
        y_pred.append(predicted_label)
        y_true.append(label.item())
    else:
        break

# Fraction of samples whose predicted index equals the true label
accuracy = accuracy_score(y_true, y_pred)
print(f"Accuracy: {accuracy}")

# Display each sample image, titled with its true and predicted label
for i in range(num_samples):
    # Channel axis moved last and then dropped, giving a 2-D array for imshow
    pixels = test_dataset[i][0].permute(1, 2, 0).numpy().squeeze()
    plt.imshow(pixels, cmap='gray')
    true_label = fashion_labels[y_true[i]]
    guess = y_pred[i]
    predicted_label = "Unknown" if guess == -1 else fashion_labels[guess]
    plt.title(f'True: {true_label}, Pred: {predicted_label}')
    plt.show()

In [None]:
import torch
from torchvision import datasets, transforms
import openai
import base64
import requests

import io
from PIL import Image
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

# Set the API key. It is read from the OPENAI_API_KEY environment variable so the
# secret does not have to be pasted into (and saved with) the notebook; the
# original blank placeholder is kept as the fallback value.
import os

openai.api_key = os.environ.get('OPENAI_API_KEY', ' ')

# Fashion-MNIST class names; the list index is the integer label used by the dataset
fashion_labels = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
]

# Image transform: convert to a tensor, then normalize.
# NOTE: Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5, so pixel values end
# up in [-1, 1] instead of the [0, 1] range produced by ToTensor.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load the Fashion-MNIST test split (downloaded into ./data when missing)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

# Data loader yielding one image per batch, in dataset order (shuffling disabled)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

# Cap on the number of images processed
num_samples = 20

# Function: encode an image as a base64 JPEG string
def encode_image(image):
    """Serialize a PIL image to JPEG in memory and return it base64-encoded.

    Args:
        image: PIL image (anything exposing ``mode``, ``convert`` and ``save``).

    Returns:
        str: base64 text of the JPEG bytes, suitable for a
        ``data:image/jpeg;base64,...`` URL.
    """
    # JPEG cannot store alpha or palette data. The original handled RGBA only;
    # palette (P) or LA images would make save() raise.
    if image.mode not in ("L", "RGB", "CMYK"):
        image = image.convert('RGB')

    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')

# Map a free-text model reply to a label index
def _match_label(text, labels):
    """Return the index of the first label (in list order) mentioned in ``text``, or -1.

    Matching is case-insensitive and uses the part of each label before any
    '/', so a reply of "T-shirt" matches 'T-shirt/top'. A label that is a
    substring of another label ('Shirt' inside 'T-shirt') only counts where
    it appears outside that longer name.
    """
    lowered = text.lower()
    names = [label.lower().split('/')[0] for label in labels]
    for index, name in enumerate(names):
        searchable = lowered
        for other in names:
            if other != name and name in other:
                # Blank out the longer name so it cannot produce a false hit
                searchable = searchable.replace(other, ' ')
        if name in searchable:
            return index
    return -1


# Function: classify an image through the API
def classify_image(image_tensor, prompt_text):
    """Send one image plus a text prompt to the chat-completions API and map the reply to a label.

    Args:
        image_tensor: batch of images as yielded by the data loader; only the
            first item (``image_tensor[0]``) is used.
        prompt_text: text sent alongside the image.

    Returns:
        int: index into ``fashion_labels`` of the matched label, or -1 when
        the request fails, an exception occurs, or no label is mentioned.
    """
    image = image_tensor[0]
    # ToPILImage multiplies float tensors by 255 and casts to bytes, so it
    # needs values in [0, 1]. Tensors that went through
    # Normalize((0.5,), (0.5,)) are in [-1, 1]; negative values identify that
    # case and are mapped back before conversion.
    if image.min() < 0:
        image = image * 0.5 + 0.5
    pil_image = transforms.ToPILImage()(image.clamp(0, 1))
    base64_image = encode_image(pil_image)

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {openai.api_key}"
    }

    # One user message: the prompt text followed by the image as a data URL
    messages_content = [
        {
            "type": "text",
            "text": prompt_text
        },
        {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/jpeg;base64,{base64_image}"
            }
        }
    ]

    payload = {
        "model": "gpt-4o-2024-08-06",
        "messages": [
            {
                "role": "user",
                "content": messages_content
            }
        ],
        "max_tokens": 300
    }

    try:
        # The timeout keeps a stalled connection from hanging the run
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=payload,
            timeout=60,
        )

        if response.status_code == 200:
            model_response = response.json()
            model_content = model_response['choices'][0]['message']['content'].strip()
            print(f"Model response: {model_content}")

            return _match_label(model_content, fashion_labels)
        else:
            print(f"Request failed with status: {response.status_code}, {response.text}")

    except Exception as e:
        # Best effort: report the problem and fall through to the error marker
        print(f"Error processing image: {e}")

    return -1

# Re-run classification for the samples whose prediction is wrong
def reclassify_incorrect(y_pred, y_true, images, incorrect_prompt):
    """Return a copy of ``y_pred`` in which every wrong entry has been re-classified.

    For each position where the prediction differs from ``y_true``, a notice
    is printed and the entry is replaced by
    ``classify_image(images[position], incorrect_prompt)``. ``y_pred`` itself
    is left untouched.
    """
    y_pred_rechecked = y_pred.copy()
    for position in range(len(y_pred)):
        if y_pred[position] == y_true[position]:
            continue
        print(f"Re-evaluating index {position}: Original prediction was incorrect.")
        y_pred_rechecked[position] = classify_image(images[position], incorrect_prompt)
    return y_pred_rechecked

# Read the prompt used when re-classifying wrong answers
def get_reclassification_prompt():
    """Ask on standard input for the re-classification prompt and return it unchanged."""
    typed_prompt = input("Enter the reclassification prompt for incorrect answers: ")
    return typed_prompt

# Accumulators: predictions, ground truth, and the image batches kept for re-classification
y_pred, y_true, images = [], [], []

# Prompt used for the initial classification pass
prompt_text = input("Please enter your text prompt: ")

# Prompt used when re-classifying wrong answers
reclassification_prompt = get_reclassification_prompt()

# Initial pass over the test loader, stopping once the sample cap is reached
for idx, (image, label) in enumerate(test_loader):
    if idx < num_samples:
        images.append(image)
        predicted_label = classify_image(image, prompt_text)
        y_pred.append(predicted_label)
        y_true.append(label.item())
    else:
        break

# Round 1: accuracy of the initial predictions
accuracy_round1 = accuracy_score(y_true, y_pred)
print(f"Accuracy Round 1: {accuracy_round1}")

# Round 2: re-classify the wrong answers with the user's re-classification prompt
y_pred_round2 = reclassify_incorrect(y_pred, y_true, images, reclassification_prompt)
accuracy_round2 = accuracy_score(y_true, y_pred_round2)
print(f"Accuracy Round 2: {accuracy_round2}")

# Round 3: repeat with the same prompt on whatever is still wrong
y_pred_round3 = reclassify_incorrect(y_pred_round2, y_true, images, reclassification_prompt)
accuracy_round3 = accuracy_score(y_true, y_pred_round3)
print(f"Accuracy Round 3: {accuracy_round3}")

# Display each sample image, titled with its true label and final prediction
for i in range(num_samples):
    # Channel axis moved last and then dropped, giving a 2-D array for imshow
    pixels = test_dataset[i][0].permute(1, 2, 0).numpy().squeeze()
    plt.imshow(pixels, cmap='gray')
    true_label = fashion_labels[y_true[i]]
    final_guess = y_pred_round3[i]
    predicted_label = "Unknown" if final_guess == -1 else fashion_labels[final_guess]
    plt.title(f'True: {true_label}, Pred: {predicted_label}')
    plt.show()

# 直接分类，如果错了，告诉大模型有错误，让他重新进行二轮三轮（两步prompt）

In [None]:
import torch
from torchvision import datasets, transforms
import openai
import base64
import requests

import io
from PIL import Image
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

# Set the API key. It is read from the OPENAI_API_KEY environment variable so the
# secret does not have to be pasted into (and saved with) the notebook; the
# original blank placeholder is kept as the fallback value.
import os

openai.api_key = os.environ.get('OPENAI_API_KEY', ' ')

# Fashion-MNIST class names; the list index is the integer label used by the dataset
fashion_labels = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
]

# Image transform: convert to a tensor, then normalize.
# NOTE: Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5, so pixel values end
# up in [-1, 1] instead of the [0, 1] range produced by ToTensor.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load the Fashion-MNIST test split (downloaded into ./data when missing)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

# Data loader yielding one image per batch, in dataset order (shuffling disabled)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

# Cap on the number of images processed
num_samples = 20

# Function: encode an image as a base64 JPEG string
def encode_image(image):
    """Serialize a PIL image to JPEG in memory and return it base64-encoded.

    Args:
        image: PIL image (anything exposing ``mode``, ``convert`` and ``save``).

    Returns:
        str: base64 text of the JPEG bytes, suitable for a
        ``data:image/jpeg;base64,...`` URL.
    """
    # JPEG cannot store alpha or palette data. The original handled RGBA only;
    # palette (P) or LA images would make save() raise.
    if image.mode not in ("L", "RGB", "CMYK"):
        image = image.convert('RGB')

    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')

# Map a free-text model reply to a label index
def _match_label(text, labels):
    """Return the index of the first label (in list order) mentioned in ``text``, or -1.

    Matching is case-insensitive and uses the part of each label before any
    '/', so a reply of "T-shirt" matches 'T-shirt/top'. A label that is a
    substring of another label ('Shirt' inside 'T-shirt') only counts where
    it appears outside that longer name.
    """
    lowered = text.lower()
    names = [label.lower().split('/')[0] for label in labels]
    for index, name in enumerate(names):
        searchable = lowered
        for other in names:
            if other != name and name in other:
                # Blank out the longer name so it cannot produce a false hit
                searchable = searchable.replace(other, ' ')
        if name in searchable:
            return index
    return -1


# Function: classify an image through the API
def classify_image(image_tensor, prompt_text):
    """Send one image plus a text prompt to the chat-completions API and map the reply to a label.

    Args:
        image_tensor: batch of images as yielded by the data loader; only the
            first item (``image_tensor[0]``) is used.
        prompt_text: text sent alongside the image.

    Returns:
        int: index into ``fashion_labels`` of the matched label, or -1 when
        the request fails, an exception occurs, or no label is mentioned.
    """
    image = image_tensor[0]
    # ToPILImage multiplies float tensors by 255 and casts to bytes, so it
    # needs values in [0, 1]. Tensors that went through
    # Normalize((0.5,), (0.5,)) are in [-1, 1]; negative values identify that
    # case and are mapped back before conversion.
    if image.min() < 0:
        image = image * 0.5 + 0.5
    pil_image = transforms.ToPILImage()(image.clamp(0, 1))
    base64_image = encode_image(pil_image)

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {openai.api_key}"
    }

    # One user message: the prompt text followed by the image as a data URL
    messages_content = [
        {
            "type": "text",
            "text": prompt_text
        },
        {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/jpeg;base64,{base64_image}"
            }
        }
    ]

    payload = {
        "model": "gpt-4o-2024-08-06",
        "messages": [
            {
                "role": "user",
                "content": messages_content
            }
        ],
        "max_tokens": 300
    }

    try:
        # The timeout keeps a stalled connection from hanging the run
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=payload,
            timeout=60,
        )

        if response.status_code == 200:
            model_response = response.json()
            model_content = model_response['choices'][0]['message']['content'].strip()
            print(f"Model response: {model_content}")

            return _match_label(model_content, fashion_labels)
        else:
            print(f"Request failed with status: {response.status_code}, {response.text}")

    except Exception as e:
        # Best effort: report the problem and fall through to the error marker
        print(f"Error processing image: {e}")

    return -1

# Re-run classification for the samples whose prediction is wrong
def reclassify_incorrect(y_pred, y_true, images, incorrect_prompt):
    """Re-classify every sample whose prediction disagrees with the ground truth.

    Works on a copy of ``y_pred``: for each mismatching position a notice is
    printed and the entry becomes
    ``classify_image(images[position], incorrect_prompt)``. The copy is
    returned; ``y_pred`` is not modified.
    """
    y_pred_rechecked = y_pred.copy()
    position = 0
    while position < len(y_pred):
        if y_pred[position] != y_true[position]:
            print(f"Re-evaluating index {position}: Original prediction was incorrect.")
            y_pred_rechecked[position] = classify_image(images[position], incorrect_prompt)
        position += 1
    return y_pred_rechecked

# Read the prompt used when re-classifying wrong answers
def get_reclassification_prompt():
    """Read the re-classification prompt from standard input and return it as typed."""
    question = "Enter the reclassification prompt for incorrect answers: "
    return input(question)

# Predictions, ground truth, and the image batches kept so they can be re-classified
y_pred, y_true, images = [], [], []

# Prompt for the initial classification pass
prompt_text = input("Please enter your text prompt: ")

# Prompt for re-classifying wrong answers
reclassification_prompt = get_reclassification_prompt()

# Initial pass: classify samples from the test loader up to the sample cap
for idx, (image, label) in enumerate(test_loader):
    if not idx < num_samples:
        break
    images.append(image)
    predicted_label = classify_image(image, prompt_text)
    y_pred.append(predicted_label)
    y_true.append(label.item())

# Round 1 accuracy: the initial predictions
accuracy_round1 = accuracy_score(y_true, y_pred)
print(f"Accuracy Round 1: {accuracy_round1}")

# Round 2: wrong answers are re-classified with the user's re-classification prompt
y_pred_round2 = reclassify_incorrect(y_pred, y_true, images, reclassification_prompt)
accuracy_round2 = accuracy_score(y_true, y_pred_round2)
print(f"Accuracy Round 2: {accuracy_round2}")

# Round 3: the same prompt is applied once more to the remaining wrong answers
y_pred_round3 = reclassify_incorrect(y_pred_round2, y_true, images, reclassification_prompt)
accuracy_round3 = accuracy_score(y_true, y_pred_round3)
print(f"Accuracy Round 3: {accuracy_round3}")

# Show each sample image with its true label and final (round 3) prediction
for i in range(num_samples):
    # Channel axis moved last and then dropped, giving a 2-D array for imshow
    sample_pixels = test_dataset[i][0].permute(1, 2, 0).numpy().squeeze()
    plt.imshow(sample_pixels, cmap='gray')
    true_label = fashion_labels[y_true[i]]
    if y_pred_round3[i] != -1:
        predicted_label = fashion_labels[y_pred_round3[i]]
    else:
        predicted_label = "Unknown"
    plt.title(f'True: {true_label}, Pred: {predicted_label}')
    plt.show()

# 综合

In [None]:
import torch
from torchvision import datasets, transforms
import openai
import base64
import requests

import io
from PIL import Image
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

# Set the API key. It is read from the OPENAI_API_KEY environment variable so the
# secret does not have to be pasted into (and saved with) the notebook; the
# original blank placeholder is kept as the fallback value.
import os

openai.api_key = os.environ.get('OPENAI_API_KEY', ' ')

# Fashion-MNIST class names; the list index is the integer label used by the dataset
fashion_labels = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
]

# Image transform: convert to a tensor, then normalize.
# NOTE: Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5, so pixel values end
# up in [-1, 1] instead of the [0, 1] range produced by ToTensor.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load the Fashion-MNIST test split (downloaded into ./data when missing)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

# Data loader yielding one image per batch, in dataset order (shuffling disabled)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

# Cap on the number of images processed
num_samples = 50

# Function: encode an image as a base64 JPEG string
def encode_image(image):
    """Serialize a PIL image to JPEG in memory and return it base64-encoded.

    Args:
        image: PIL image (anything exposing ``mode``, ``convert`` and ``save``).

    Returns:
        str: base64 text of the JPEG bytes, suitable for a
        ``data:image/jpeg;base64,...`` URL.
    """
    # JPEG cannot store alpha or palette data. The original handled RGBA only;
    # palette (P) or LA images, common for PNG/GIF files, would make save() raise.
    if image.mode not in ("L", "RGB", "CMYK"):
        image = image.convert('RGB')

    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')

# Step 1: read a user-chosen image that serves as the conceptual primer
def get_image_prompt():
    """Ask for an image path, load the image and return it base64-encoded.

    Returns:
        str | None: the value returned by ``encode_image`` for the loaded
        image, or ``None`` when the file cannot be opened or encoded (the
        error is printed).
    """
    # Pasted paths often carry stray surrounding whitespace
    image_path = input("Enter the path of the image to use as a conceptual prompt: ").strip()
    try:
        # Image.open keeps the file open until the data is loaded; the context
        # manager guarantees the handle is released once encoding is done.
        with Image.open(image_path) as image:
            return encode_image(image)
    except Exception as e:
        print(f"Could not open image: {e}")
        return None

# Step 2: read extra text used to teach the model additional concepts
def get_additional_concept_text():
    """Ask on standard input for the additional concept text and return it unchanged."""
    concept_text = input("Enter a text prompt to further teach the model some concepts: ")
    return concept_text

# Step 3: read the text prompt used for classification
def get_text_prompt():
    """Read the classification prompt from standard input and return it as typed."""
    question = "Please enter your text prompt: "
    return input(question)

# Read the prompt used when re-classifying wrong answers
def get_reclassification_prompt():
    """Ask on standard input for the re-classification prompt and return it unchanged."""
    retry_prompt = input("Enter the reclassification prompt for incorrect answers: ")
    return retry_prompt

# Map a free-text model reply to a label index
def _match_label(text, labels):
    """Return the index of the first label (in list order) mentioned in ``text``, or -1.

    Matching is case-insensitive and uses the part of each label before any
    '/', so a reply of "T-shirt" matches 'T-shirt/top'. A label that is a
    substring of another label ('Shirt' inside 'T-shirt') only counts where
    it appears outside that longer name.
    """
    lowered = text.lower()
    names = [label.lower().split('/')[0] for label in labels]
    for index, name in enumerate(names):
        searchable = lowered
        for other in names:
            if other != name and name in other:
                # Blank out the longer name so it cannot produce a false hit
                searchable = searchable.replace(other, ' ')
        if name in searchable:
            return index
    return -1


# Function: classify an image through the API
def classify_image(image_tensor, prompt_text):
    """Send one image plus a text prompt to the chat-completions API and map the reply to a label.

    Args:
        image_tensor: batch of images as yielded by the data loader; only the
            first item (``image_tensor[0]``) is used.
        prompt_text: text sent alongside the image.

    Returns:
        int: index into ``fashion_labels`` of the matched label, or -1 when
        the request fails, an exception occurs, or no label is mentioned.
    """
    image = image_tensor[0]
    # ToPILImage multiplies float tensors by 255 and casts to bytes, so it
    # needs values in [0, 1]. Tensors that went through
    # Normalize((0.5,), (0.5,)) are in [-1, 1]; negative values identify that
    # case and are mapped back before conversion.
    if image.min() < 0:
        image = image * 0.5 + 0.5
    pil_image = transforms.ToPILImage()(image.clamp(0, 1))
    base64_image = encode_image(pil_image)

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {openai.api_key}"
    }

    # One user message: the prompt text followed by the image as a data URL
    messages_content = [
        {
            "type": "text",
            "text": prompt_text
        },
        {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/jpeg;base64,{base64_image}"
            }
        }
    ]

    payload = {
        "model": "gpt-4o-2024-08-06",
        "messages": [
            {
                "role": "user",
                "content": messages_content
            }
        ],
        "max_tokens": 300
    }

    try:
        # The timeout keeps a stalled connection from hanging the run
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=payload,
            timeout=60,
        )

        if response.status_code == 200:
            model_response = response.json()
            model_content = model_response['choices'][0]['message']['content'].strip()
            print(f"Model response: {model_content}")

            return _match_label(model_content, fashion_labels)
        else:
            print(f"Request failed with status: {response.status_code}, {response.text}")

    except Exception as e:
        # Best effort: report the problem and fall through to the error marker
        print(f"Error processing image: {e}")

    return -1

# Re-run classification for the samples whose prediction is wrong
def reclassify_incorrect(y_pred, y_true, images, incorrect_prompt):
    """Return a copy of ``y_pred`` with each wrong entry replaced by a fresh classification.

    Positions are visited in order; wherever the prediction differs from
    ``y_true`` a notice is printed and ``classify_image`` is called with the
    stored image and ``incorrect_prompt``. The input list is not modified.
    """
    y_pred_rechecked = y_pred.copy()
    for position, _ in enumerate(y_pred):
        was_correct = y_pred[position] == y_true[position]
        if was_correct:
            continue
        print(f"Re-evaluating index {position}: Original prediction was incorrect.")
        y_pred_rechecked[position] = classify_image(images[position], incorrect_prompt)
    return y_pred_rechecked

# Predictions, ground truth, and the image batches kept for re-classification
y_pred = []
y_true = []
images = []

# Read and encode the concept image
base64_image_prompt = get_image_prompt()
if not base64_image_prompt:
    print("Image prompt not obtained, exiting program.")
    # exit() is a helper meant for the interactive interpreter; raising
    # SystemExit is the programmatic way to stop here.
    raise SystemExit

# Send the concept image to the model and show its reply.
# NOTE: this and the text-concept request below are standalone requests. The
# classify_image calls later each send only their own single message, so these
# exchanges are not part of their context.
headers = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {openai.api_key}"
}

concept_payload = {
    "model": "gpt-4o-2024-08-06",
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{base64_image_prompt}"
                    }
                },
                {
                    "type": "text",
                    "text": "Please use this image to understand the concept for further classification tasks."
                }
            ]
        }
    ],
    "max_tokens": 300
}

try:
    # The timeout keeps a stalled connection from hanging the notebook
    concept_response = requests.post(
        "https://api.openai.com/v1/chat/completions",
        headers=headers,
        json=concept_payload,
        timeout=60,
    )
    if concept_response.status_code == 200:
        concept_model_response = concept_response.json()
        concept_model_content = concept_model_response['choices'][0]['message']['content'].strip()
        print(f"Conceptual understanding response: {concept_model_content}")
    else:
        print(f"Concept setup failed with status: {concept_response.status_code}, {concept_response.text}")
except Exception as e:
    print(f"Error setting concept: {e}")

# Send the additional text concept and show the model's reply
additional_concept_text = get_additional_concept_text()
additional_concept_payload = {
    "model": "gpt-4o-2024-08-06",
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": additional_concept_text
                }
            ]
        }
    ],
    "max_tokens": 300
}

try:
    additional_concept_response = requests.post(
        "https://api.openai.com/v1/chat/completions",
        headers=headers,
        json=additional_concept_payload,
        timeout=60,
    )
    if additional_concept_response.status_code == 200:
        additional_concept_model_response = additional_concept_response.json()
        additional_concept_model_content = additional_concept_model_response['choices'][0]['message']['content'].strip()
        print(f"Additional concept understanding response: {additional_concept_model_content}")
    else:
        print(f"Additional concept setup failed with status: {additional_concept_response.status_code}, {additional_concept_response.text}")
except Exception as e:
    print(f"Error setting additional concept: {e}")

# Read the user's classification prompt
prompt_text = get_text_prompt()

# Read the prompt used when re-classifying wrong answers
reclassification_prompt = get_reclassification_prompt()

# Initial pass: classify the first num_samples test images
for idx, (image, label) in zip(range(num_samples), test_loader):
    images.append(image)
    predicted_label = classify_image(image, prompt_text)
    y_pred.append(predicted_label)
    y_true.append(label.item())

# Round 1: accuracy of the initial predictions (-1 counts as an error)
accuracy_round1 = accuracy_score(y_true, y_pred)
print(f"Accuracy Round 1: {accuracy_round1}")

# Round 2: re-classify the wrong answers with the user's re-classification prompt
y_pred_round2 = reclassify_incorrect(y_pred, y_true, images, reclassification_prompt)
accuracy_round2 = accuracy_score(y_true, y_pred_round2)
print(f"Accuracy Round 2: {accuracy_round2}")

# Round 3: repeat with the same prompt on whatever is still wrong
y_pred_round3 = reclassify_incorrect(y_pred_round2, y_true, images, reclassification_prompt)
accuracy_round3 = accuracy_score(y_true, y_pred_round3)
print(f"Accuracy Round 3: {accuracy_round3}")

# Show every processed image with its true label and final prediction.
# Bounded by the samples actually collected so a short run cannot raise IndexError.
for i in range(len(y_true)):
    # Move the channel axis last and drop it to get a 2-D array for imshow
    plt.imshow(test_dataset[i][0].permute(1, 2, 0).numpy().squeeze(), cmap='gray')
    true_label = fashion_labels[y_true[i]]
    predicted_label = fashion_labels[y_pred_round3[i]] if y_pred_round3[i] != -1 else "Unknown"
    plt.title(f'True: {true_label}, Pred: {predicted_label}')
    plt.show()