In [1]:
import json
from tqdm import tqdm
import re
import torch
import pandas as pd
# Colour name -> [low, high] normalised pixel-intensity band.
color_mapping = dict(
    red=[0.1, 0.15],
    green=[0.15, 0.3],
    blue=[0.3, 0.45],
    yellow=[0.45, 0.6],
    orange=[0.6, 0.75],
    purple=[0.75, 0.9],
)

# All colour names, in definition order (a live keys view of color_mapping).
colors = color_mapping.keys()

# 检查颜色是否存在于字符串中
def find_colors_in_string(input_string, candidates=None):
    """Return the colour names that occur as whole words in ``input_string``.

    Matching is on regex word boundaries, so words that merely contain a
    colour name (e.g. "tired" or "hundred" for "red") are not reported.
    Matching is case-sensitive. Results keep the order of ``candidates``.

    Args:
        input_string: text to search.
        candidates: iterable of colour names to look for. Defaults to the
            module-level ``colors`` (the keys of ``color_mapping``).

    Returns:
        List of matching colour names; empty if none are found.
    """
    if candidates is None:
        candidates = colors
    return [
        color for color in candidates
        # re.escape keeps any regex metacharacters in a name literal.
        if re.search(rf"\b{re.escape(color)}\b", input_string)
    ]


def map_to_color(pixel):
    """Bucket a normalised pixel value into a colour name.

    Values below 0.1 are 'black'. Bands are half-open [low, high) except the
    last ('purple', 0.75..0.9 inclusive). Anything else, including values
    above 0.9 and NaN, is 'other'.
    """
    if pixel < 0.1:
        return 'black'
    # Walk exclusive upper bounds in ascending order; the lower bound of each
    # band is implied by having failed the previous comparison.
    upper_bounds = (
        (0.15, 'red'),
        (0.3, 'green'),
        (0.45, 'blue'),
        (0.6, 'yellow'),
        (0.75, 'orange'),
    )
    for limit, name in upper_bounds:
        if pixel < limit:
            return name
    return 'purple' if pixel <= 0.9 else 'other'

def compare_img(gen_img,gt_img):
    """Count pixels whose colour bucket agrees / disagrees between two images.

    Iterates over the rows and columns of ``gen_img`` and looks up the pixel
    at the same position in ``gt_img``; both are bucketed with map_to_color.

    Returns:
        (correct_pixel, incorrect_pixel) counts.
    """
    matches = 0
    mismatches = 0
    for row_idx, gen_row in enumerate(gen_img):
        for col_idx, gen_px in enumerate(gen_row):
            gen_bucket = map_to_color(gen_px)
            gt_bucket = map_to_color(gt_img[row_idx][col_idx])
            if gen_bucket == gt_bucket:
                matches += 1
            else:
                mismatches += 1
    return matches, mismatches

# Per-checkpoint scoring of the smart-watch understanding answers.
# For each checkpoint in answer_list, its answer file is scored per category
# (time / weather / position / battery). Results are printed, stacked into a
# (category x checkpoint) tensor `acc`, and written to output.csv.

TIME_PATTERN = re.compile(r"(\d{2}):(\d{2}):(\d{2})")
BATTERY_PATTERN = re.compile(r'\b(100|[1-9]\d?|0)%')
CATEGORIES = ('time', 'weather', 'position', 'battery')
WEATHER_WORDS = ('sunny', 'raining', 'cloudy')


def _circular_distance(a, b, period):
    """Shortest distance between ``a`` and ``b`` on a dial with ``period`` ticks.

    Reducing modulo ``period`` first keeps the result in [0, period / 2] even
    when a predicted field is out of range (e.g. hour 23 on a 12-tick dial).
    Without it the normalised error could go negative and the score exceed 1.
    """
    diff = abs(a - b) % period
    return min(diff, period - diff)


def _time_score(ground_truth, answer_text):
    """Score an HH:MM:SS answer against the ground truth.

    Hours are compared on a 12-tick dial (normalised by 6) and minutes and
    seconds on a 60-tick dial (normalised by 30). The three errors are
    averaged and the score is 1 - error. An answer without HH:MM:SS scores 0.

    Raises:
        ValueError: if the ground truth contains no HH:MM:SS.
    """
    gt_match = TIME_PATTERN.search(ground_truth)
    if gt_match is None:
        raise ValueError(f"Unparseable time ground truth: {ground_truth}")
    ans_match = TIME_PATTERN.search(answer_text)
    if ans_match is None:
        return 0.0
    gt_h, gt_m, gt_s = (int(x) for x in gt_match.groups())
    ans_h, ans_m, ans_s = (int(x) for x in ans_match.groups())
    err = (_circular_distance(ans_h, gt_h, 12) / 6.0
           + _circular_distance(ans_m, gt_m, 60) / 30.0
           + _circular_distance(ans_s, gt_s, 60) / 30.0) / 3.0
    return 1 - err


def _battery_score(ground_truth, answer_text):
    """Score a battery percentage answer as 1 - |answer - truth| (as fractions).

    An answer without a percentage scores 0.

    Raises:
        ValueError: if the ground truth contains no percentage.
    """
    gt_match = BATTERY_PATTERN.search(ground_truth)
    if gt_match is None:
        raise ValueError(f"Unparseable battery ground truth: {ground_truth}")
    ans_match = BATTERY_PATTERN.search(answer_text)
    if ans_match is None:
        return 0.0
    gt = int(gt_match.group(1)) / 100
    ans = int(ans_match.group(1)) / 100
    return 1 - abs(ans - gt)


def _score_answer_file(data_path):
    """Return ``(scores, counts)`` dicts keyed by CATEGORIES for one JSONL file.

    Only records whose prompt starts with '<' are scored. Ground truths that
    contain a space are skipped.

    Raises:
        ValueError: for a ground truth that matches no known category.
    """
    scores = dict.fromkeys(CATEGORIES, 0)
    counts = dict.fromkeys(CATEGORIES, 0)
    with open(data_path, "r") as f:
        for line in tqdm(f):
            json_obj = json.loads(line.strip())
            ground_truth = json_obj['groun_truth']  # key spelling as stored in the answer files
            answer_text = json_obj['answer']
            prompt = json_obj['prompt']
            # startswith() also handles an empty prompt (prompt[0] would raise).
            if not prompt.startswith('<'):
                continue
            if ' ' in ground_truth:
                continue
            if ':' in ground_truth:
                category, score = 'time', _time_score(ground_truth, answer_text)
            elif any(word in ground_truth for word in WEATHER_WORDS):
                category, score = 'weather', int(ground_truth in answer_text)
            elif '-' in ground_truth:
                category, score = 'position', int(ground_truth in answer_text)
            elif '%' in ground_truth:
                category, score = 'battery', _battery_score(ground_truth, answer_text)
            else:
                raise ValueError(f"Unknown ground truth format: {ground_truth}")
            counts[category] += 1
            scores[category] += score
    return scores, counts


def _ratio(num, den):
    """num / den, or NaN when den is 0 (the category is absent from the file)."""
    return num / den if den else float('nan')


acc=[]
answer_list=[i*45 for i in range(1,11)]
for checkpoint in answer_list:
    data_path=f'./answer/answer-llava-v1.5-7b-vq-vq-2-sw-lora-{checkpoint}.jsonl'
    scores, counts = _score_answer_file(data_path)
    time_acc, weather_acc, position_acc, battery_acc = (
        _ratio(scores[c], counts[c]) for c in CATEGORIES)
    total_acc = _ratio(sum(scores.values()), sum(counts.values()))
    acc.append([time_acc,weather_acc,position_acc,battery_acc,total_acc])
    print(f"time_acc: {time_acc:.4f}, weather_acc: {weather_acc:.4f}, position_acc: {position_acc:.4f}, battery_acc: {battery_acc:.4f}, total_acc: {total_acc:.4f}")

# Rows = metrics, columns = checkpoints.
acc=torch.Tensor(acc)
acc=acc.permute(1,0)
print(f"time_acc:{acc[0]}")
print(f"weather_acc:{acc[1]}")
print(f"position_acc:{acc[2]}")
print(f"battery_acc:{acc[3]}")
print(f"total_acc:{acc[4]}")

# Column labels assume the checkpoints are evenly spaced in 10% steps of training.
columns = [f"{i*10}%" for i in range(1, 1+acc.shape[1])]
row_index = ['time_acc', 'weather_acc', 'position_acc', 'battery_acc', 'total_acc']
df = pd.DataFrame(data=acc.numpy(),
                  index=row_index,
                  columns=columns)
df.to_csv("output.csv", index=True, header=True)



In [17]:
import json
# Load the training JSON and report how many records it contains.
data_path='/public_data/jihai/data/multimodalout/smart_watch_train_120ku_180km.json'
# Context manager closes the file; json.load(open(...)) leaked the handle.
with open(data_path, "r") as train_file:
    data_list=json.load(train_file)
print(len(data_list))

In [None]:
import matplotlib.pyplot as plt
import random
import os
from PIL import Image, ImageDraw, ImageFont,ImageColor

# Visual spot-check of one checkpoint's generated images. For the first
# MAX_PLOTS records whose prompt does NOT start with '<', save a side-by-side
# ground-truth / generated figure to PLOT_DIR/<n>.png. Also tally how often
# each attribute category appears in those prompts.
gt_image_folder='/public_data/jihai/data/multimodalout/smart_watch_image_test'
PLOT_DIR='./calculate'
MAX_PLOTS=30

answer_list=[i*45 for i in range(1,11)]
checkpoint=answer_list[-1]
data_path=f'./answer/answer-llava-v1.5-7b-vq-vq-2-sw-lora-{checkpoint}.jsonl'
# savefig() does not create missing directories.
os.makedirs(PLOT_DIR, exist_ok=True)

count=0
time_count=0
weather_count=0
position_count=0
battery_count=0
with open(data_path, "r") as f:
    for line in tqdm(f):
        json_obj = json.loads(line.strip())
        prompt=json_obj['prompt']
        # Prompts starting with '<' are the ones scored by the first cell; skip them here.
        if prompt.startswith('<'):
            continue

        # Category tallies; one prompt may count towards several categories.
        if ':' in prompt:
            time_count+=1
        if any(word in prompt for word in ('sunny', 'raining', 'cloudy')):
            weather_count+=1
        position_count+=prompt.count('-')
        if '%' in prompt:
            battery_count+=1

        gt_image_path=os.path.join(gt_image_folder,json_obj['groun_truth_img_tensor'])
        # copy() forces the pixel data to load so the file handle can be closed now.
        with Image.open(gt_image_path) as gt_file:
            gt_image=gt_file.copy()
        # NOTE(review): torch.load unpickles the file; only run on trusted answer files.
        gen_image=torch.load(json_obj['output_img_file'])

        print(prompt)
        fig, axes = plt.subplots(1, 2, figsize=(10, 5))

        # Left: ground-truth image.
        axes[0].imshow(gt_image)
        axes[0].set_title("Ground Truth Image")
        axes[0].axis('off')

        # Right: generated image.
        axes[1].imshow(gen_image)
        axes[1].set_title("Generated Image")
        axes[1].axis('off')

        # Leave the top 10% free for an optional suptitle.
        fig.tight_layout(rect=[0, 0, 1, 0.9])
        fig.savefig(os.path.join(PLOT_DIR, f'{count}.png'))
        plt.show()
        # Close explicitly so MAX_PLOTS figures do not accumulate in memory.
        plt.close(fig)

        count+=1
        if count==MAX_PLOTS:
            break
print(f"time_count:{time_count}")
print(f"weather_count:{weather_count}")
print(f"position_count:{position_count}")
print(f"battery_count:{battery_count}")

In [9]:
import re
# Sanity check: the HH:MM:SS pattern used for time scoring extracts the
# hour field from a decoded answer string.
g = "<s> 00:50:38</s>"
pattern = r"(\d{2}):(\d{2}):(\d{2})"
match = re.compile(pattern).search(g)
hour_field = match.groups()[0]
print(hour_field)

In [2]:
# Scratch arithmetic left from exploration (purpose not recorded).
(196 * 2) / 3