In [2]:
print("vqa")

vqa


In [2]:
%pip install -q transformers accelerate pillow tqdm

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
diffusers 0.35.2 requires importlib_metadata, which is not installed.
rembg 2.0.68 requires jsonschema, which is not installed.
ultralytics 8.3.233 requires opencv-python>=4.6.0, which is not installed.[0m[31m
[0mNote: you may need to restart the kernel to use updated packages.


In [6]:
%pip install huggingface_hub


Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Note: you may need to restart the kernel to use updated packages.


In [18]:
%pip install -q transformers accelerate qwen-vl-utils pillow tqdm


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Note: you may need to restart the kernel to use updated packages.


In [1]:
# ============================================================
# 1. 기본 세팅 및 모델 로드
# ============================================================
import os
import json
from PIL import Image
from tqdm import tqdm
import numpy as np
import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

# Hugging Face token을 이미 환경변수에 넣어뒀다면 생략 가능.
from huggingface_hub import login
#login()  # 한 번만 실행해두면 됨.

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

MODEL_NAME = "Qwen/Qwen2-VL-7B-Instruct"

processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    device_map="auto"  # GPU 있으면 자동으로 올림
).eval()

# shot label space 정의 (index 0,1,2 <-> full, medium, closeup)
shot_labels = ["full", "medium", "closeup"]
label2id = {name: i for i, name in enumerate(shot_labels)}
id2label = {i: name for name, i in label2id.items()}

print(label2id)


  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release.
`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 5/5 [00:05<00:00,  1.12s/it]


{'full': 0, 'medium': 1, 'closeup': 2}


In [2]:
# ============================================================
# 2. VQA용 질문 프롬프트 & 헬퍼 함수들
# ============================================================

VQA_QUESTION = (
    "Classify the camera shot size of the MAIN PERSON. "
    "Definitions: full = whole body visible head-to-toe; "
    "medium = from waist/chest up; "
    "closeup = face fills most of the frame. "
    "Answer with exactly one word: full, medium, or closeup."
)


def normalize_vqa_answer(ans: str) -> str:
    if ans is None:
        return "unknown"
    a = ans.strip().lower()
    last = a.splitlines()[-1].strip('",.()[] ')
    if last in ("full", "medium", "closeup"):
        return last
    if "close" in a:
        return "closeup"
    if "medium" in a:
        return "medium"
    if "full" in a:
        return "full"
    return "unknown"


@torch.no_grad()
def vqa_predict_shot(image: Image.Image, question: str = VQA_QUESTION, debug:bool=False) -> str:
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": question},
            ],
        }
    ]

    # chat template + vision preprocessing (Qwen2-VL 공식 예제 흐름) 
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)

    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt"
    )

    # device_map="auto"를 썼으면 model이 알아서 장치 분산될 수 있으니,
    # inputs는 가능한 한 첫 디바이스로 이동
    if device == "cuda":
        inputs = {k: v.to("cuda") for k, v in inputs.items() if torch.is_tensor(v)}
    else:
        inputs = {k: v.to("cpu") for k, v in inputs.items() if torch.is_tensor(v)}

    out_ids = model.generate(**inputs, max_new_tokens=10)
    raw = processor.batch_decode(out_ids, skip_special_tokens=True)[0]
    
    if debug:
        print("[RAW]", repr(raw))
        print("[NORM]", normalize_vqa_answer(raw))

    return normalize_vqa_answer(raw)



In [3]:
# ============================================================
# 3. GT json에서 shot 정보 파싱 함수
#    예: {"file_name": "SF090738.JPEG",
#         "text": "<ms_trg>, medium shot, Eye level, ..."}
# ============================================================

def parse_gt_shot_from_text(text: str) -> str:
    """
    GT json 안의 'text' 필드에서 shot 정보를 추출.
    샘플 예시 기준:
      - 'medium shot' 이라는 단어 포함
    필요하다면 여기에 규칙을 더 추가하면 됨.
    """
    t = text.lower()
    # 우선 순위: full -> medium -> closeup
    if "full shot" in t:
        return "full"
    if "medium shot" in t:
        return "medium"
    if "closeup" in t or "close-up" in t or "close up" in t:
        return "closeup"
    # 토큰 형태 (<ms_trg> 등)이 있다면 여기에 추가
    if "<fs" in t:
        return "full"
    if "<ms" in t:
        return "medium"
    if "<cu" in t:
        return "closeup"
    return "unknown"


In [4]:
# ============================================================
# 4. 경로 설정
#    (네 환경에 맞게 파일명/경로만 수정하면 됨)
# ============================================================

# A_hat용 GT 데이터
DATA_A_DIR = "/home/aikusrv01/storyboard/TK/Dataset_fin"            # GT 이미지 폴더
GT_JSON_PATH = os.path.join(DATA_A_DIR, "metadata.jsonl")  # 예시 이름, 실제 파일명에 맞게 수정

# q_hat_s용 생성 이미지 폴더들
#DATA_Q_DIRS = {
#    1: "vqa_test/dataQ1",   # S=1 (예: full shot prompt)
#    2: "vqa_test/dataQ2",   # S=2 (예: medium shot prompt)
#    3: "vqa_test/dataQ3",   # S=3 (예: closeup prompt)
#}

DATA_Q_DIR = "/home/aikusrv01/storyboard/TK/validation/output/base_sd"               # 생성 이미지 폴더 (한 곳)
Q_JSONL_PATH = "/home/aikusrv01/storyboard/TK/validation/validation.jsonl"    # 생성 이미지 메타 jsonl (파일명 맞게 수정!)


# S index와 shot label의 매핑 (너의 이론에서 s=1,2,3 순서 정해야 함)
S_TO_LABEL = {
    1: "full",
    2: "medium",
    3: "closeup",
}


In [5]:
def load_jsonl(path):
    items = []
    with open(path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            items.append(json.loads(line))
    return items


In [6]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4"  # 여기 숫자를 바꿔서 빈 GPU로

In [9]:
import random
gt_items=load_jsonl(GT_JSON_PATH)

sample_idx = random.sample(range(len(gt_items)), 5)

for idx in sample_idx:
    item = gt_items[idx]
    file_name = item["file_name"]
    text = item.get("text", "")
    gt_shot = parse_gt_shot_from_text(text)

    img_path = os.path.join(DATA_A_DIR, file_name)
    image = Image.open(img_path).convert("RGB")

    # raw 출력하려면 vqa_predict_shot에 debug 옵션이 있는 게 편함
    pred = vqa_predict_shot(image, debug=True)

    print("-----")
    print("file:", file_name)
    print("GT:", gt_shot)
    print("mapped pred:", pred)
    print("text:", text[:150])


[RAW] 'system\nYou are a helpful assistant.\nuser\nAnalyze the MAIN CHARACTER in the image.Determine the visibility of specific body parts to classify the shot type.Return the answer in JSON format with two keys: \'visible_parts\' and \'shot_type\'.Step 1: Check visibility (Yes/No)- Are the FEET visible?- Is the WAIST or HIPS visible?- Is ONLY the HEAD/SHOULDERS visible?Step 2: Determine Shot Type based on logic- If FEET are visible -> \'full\'- If FEET are NOT visible BUT WAIST is visible -> \'medium\'- If WAIST is NOT visible (only head/shoulders) -> \'closeup\'Output format: {\'visible_parts\': [list of parts], \'shot_type\': \'answer\'}\nassistant\n```json\n{\n  "visible_parts": ["'
[NORM] closeup
-----
file: SD127389.JPEG
GT: medium
mapped pred: closeup
text: <ms_trg>, medium shot, Low angle, youth, stone face, T-shirt / sleeveless, and look to the left.
[RAW] 'system\nYou are a helpful assistant.\nuser\nAnalyze the MAIN CHARACTER in the image.Determine the visibility of specific 

KeyboardInterrupt: 

In [20]:
# ============================================================
# 5. A_hat 계산
#    - M개의 이미지 (dataA)
#    - GT 샷타입 (json)
#    - VQA 출력 샷타입
# ============================================================

# GT json 로드
gt_items=load_jsonl(GT_JSON_PATH)
gt_items = gt_items[:100]   # ★ 테스트용 (10개만)

#with open(GT_JSON_PATH, "r", encoding="utf-8") as f:
#    gt_items = json.load(f)  # 리스트라고 가정

print("Number of GT items (M):", len(gt_items))

# confusion matrix 카운트 (3x3)
conf_counts = np.zeros((3, 3), dtype=np.int64)
row_totals = np.zeros(3, dtype=np.int64)

num_skipped = 0

PRINT_EVERY = 10   # 몇 개마다 출력할지
cnt = 0

for item in tqdm(gt_items, desc="Computing A_hat with Qwen"):
    cnt += 1
    if cnt % PRINT_EVERY == 0:
        print(f"\n--- After {cnt} samples ---")
        print("Row totals:", row_totals)
        print("Confusion counts:\n", conf_counts)
        tmp_A = np.zeros_like(conf_counts, dtype=np.float64)
        for i in range(3):
            if row_totals[i] > 0:
                tmp_A[i] = conf_counts[i] / row_totals[i]
        print("A_hat (partial):\n", tmp_A)


    file_name = item["file_name"]
    text = item.get("text", "")
    gt_shot = parse_gt_shot_from_text(text)

    if gt_shot not in label2id:
        # shot 정보를 제대로 못 뽑은 경우 스킵
        num_skipped += 1
        continue

    gt_id = label2id[gt_shot]

    img_path = os.path.join(DATA_A_DIR, file_name)
    if not os.path.exists(img_path):
        print(f"[WARN] Image not found: {img_path}")
        num_skipped += 1
        continue

    image = Image.open(img_path).convert("RGB")
    pred_shot = vqa_predict_shot(image)
    if pred_shot not in label2id:
        # VQA가 엉뚱한 답(unknown)을 준 경우 스킵하거나,
        # 필요시 'unknown'이라는 4번째 클래스 추가해서 다루어도 됨.
        num_skipped += 1
        continue

    pred_id = label2id[pred_shot]

    conf_counts[gt_id, pred_id] += 1
    row_totals[gt_id] += 1

print("Skipped samples:", num_skipped)
print("Row totals (M_i):", row_totals)

# 행별로 정규화해서 A_hat 만들기
A_hat = np.zeros_like(conf_counts, dtype=np.float64)
for i in range(3):
    if row_totals[i] > 0:
        A_hat[i, :] = conf_counts[i, :] / row_totals[i]
    else:
        A_hat[i, :] = 0.0

print("Confusion counts:\n", conf_counts)
print("Estimated A_hat:\n", A_hat)


Number of GT items (M): 100


Computing A_hat with Qwen:   9%|▉         | 9/100 [04:19<38:23, 25.31s/it]


--- After 10 samples ---
Row totals: [0 3 6]
Confusion counts:
 [[0 0 0]
 [0 0 3]
 [0 0 6]]
A_hat (partial):
 [[0. 0. 0.]
 [0. 0. 1.]
 [0. 0. 1.]]


Computing A_hat with Qwen:  19%|█▉        | 19/100 [05:15<10:05,  7.47s/it]


--- After 20 samples ---
Row totals: [ 0 12  7]
Confusion counts:
 [[ 0  0  0]
 [ 0  0 12]
 [ 0  0  7]]
A_hat (partial):
 [[0. 0. 0.]
 [0. 0. 1.]
 [0. 0. 1.]]


Computing A_hat with Qwen:  29%|██▉       | 29/100 [07:15<07:38,  6.45s/it]


--- After 30 samples ---
Row totals: [ 0 19 10]
Confusion counts:
 [[ 0  0  0]
 [ 0  0 19]
 [ 0  0 10]]
A_hat (partial):
 [[0. 0. 0.]
 [0. 0. 1.]
 [0. 0. 1.]]


Computing A_hat with Qwen:  31%|███       | 31/100 [08:18<18:29, 16.08s/it]


KeyboardInterrupt: 

In [7]:
A_hat = [[0.86956522, 0.13043478, 0.        ],
 [0.32608696, 0.49130435, 0.1826087 ],
 [0.17391304, 0.32608696, 0.5       ]]
A_hat

[[0.86956522, 0.13043478, 0.0],
 [0.32608696, 0.49130435, 0.1826087],
 [0.17391304, 0.32608696, 0.5]]

In [8]:
vqa_acc_uniform = np.trace(A_hat) / 3.0
print(f"VQA accuracy (uniform GT assumption): {vqa_acc_uniform:.4f}")

VQA accuracy (uniform GT assumption): 0.6203


In [9]:
'''# ============================================================
# 6. q_hat_s 계산
#    - S=s 조건에서 생성된 N_s장 이미지 (dataQ1, dataQ2, dataQ3)
#    - 각 이미지에 대한 VQA 샷 예측
# ============================================================

# q_hat_s를 shot index s=0,1,2로 둘 수도 있지만,
# 여기서는 s = 1,2,3 (문제에서 쓰던 notation)을 그대로 사용.
# q_hat_dict[s] = 길이 3 벡터 (Z가 full/medium/closeup일 확률)

q_hat_dict = {}
N_s_dict = {}

for s in [1, 2, 3]:
    dir_path = DATA_Q_DIRS[s]
    if not os.path.isdir(dir_path):
        print(f"[WARN] Folder for S={s} not found:", dir_path)
        continue

    file_list = [
        f for f in os.listdir(dir_path)
        if f.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".webp"))
    ]
    file_list.sort()

    counts = np.zeros(3, dtype=np.int64)
    total = 0
    skipped = 0

    print(f"\n[ S = {s} (prompt for {S_TO_LABEL[s]}) ]")
    print(f"Number of images found: {len(file_list)}")

    for fname in tqdm(file_list, desc=f"VQA for S={s}"):
        img_path = os.path.join(dir_path, fname)
        image = Image.open(img_path).convert("RGB")
        pred_shot = vqa_predict_shot(image)

        if pred_shot not in label2id:
            skipped += 1
            continue

        pred_id = label2id[pred_shot]
        counts[pred_id] += 1
        total += 1

    print(f"S={s}: used {total}, skipped {skipped}")

    if total > 0:
        q_hat = counts / total
    else:
        q_hat = np.zeros(3, dtype=np.float64)

    q_hat_dict[s] = q_hat
    N_s_dict[s] = total

    print(f"q_hat_{s} (Z distribution | S={s}):", q_hat)'''

# ============================================================
# 6. q_hat_s 계산 (jsonl 기반: file_name + text에서 S 추출)
# ============================================================

q_items = load_jsonl(Q_JSONL_PATH)
#q_items = random.sample(q_items, 20)

print("Number of generated items:", len(q_items))

# S별 counts 저장
counts_by_s = {1: np.zeros(3, dtype=np.int64),
               2: np.zeros(3, dtype=np.int64),
               3: np.zeros(3, dtype=np.int64)}
totals_by_s = {1: 0, 2: 0, 3: 0}
skipped_by_s = {1: 0, 2: 0, 3: 0}

# text에서 "의도된 샷" = S 를 뽑기
# (A_hat에서 쓰던 parse_gt_shot_from_text를 그대로 재사용)
def parse_S_from_text(text: str) -> int:
    shot = parse_gt_shot_from_text(text)  # "full"/"medium"/"closeup"/"unknown"
    if shot not in label2id:
        return -1
    return {"full": 1, "medium": 2, "closeup": 3}[shot]

for item in tqdm(q_items, desc="Computing q_hat_s with BLIP2 (jsonl)"):
    file_name = item["file_name"]
    text = item.get("text", "")

    s = parse_S_from_text(text)
    if s not in (1, 2, 3):
        continue

    img_path = os.path.join(DATA_Q_DIR, file_name)

    if not os.path.exists(img_path):
        # 확장자만 png로 바꿔서 재시도 (JPEG/JPG/jpeg 등 모두 대응)
        base, _ = os.path.splitext(img_path)
        cand = base + ".png"
        if os.path.exists(cand):
            img_path = cand
        else:
            cand = base + ".PNG"
            if os.path.exists(cand):
                img_path = cand
            else:
                print(f"[WARN] Image not found: {img_path} (also tried .png)")
                skipped_by_s[s] += 1
                continue


    image = Image.open(img_path).convert("RGB")
    pred_shot = vqa_predict_shot(image)
    if pred_shot not in label2id:
        skipped_by_s[s] += 1
        continue

    pred_id = label2id[pred_shot]
    counts_by_s[s][pred_id] += 1
    totals_by_s[s] += 1

# 최종 q_hat_dict 생성 (Section 7에서 그대로 사용 가능)
q_hat_dict = {}
N_s_dict = {}

for s in (1, 2, 3):
    total = totals_by_s[s]
    N_s_dict[s] = total
    q_hat_dict[s] = counts_by_s[s] / total if total > 0 else np.zeros(3, dtype=np.float64)

    print(f"S={s}: used {total}, skipped {skipped_by_s[s]}")
    print(f"q_hat_{s} =", q_hat_dict[s])



Number of generated items: 747


Computing q_hat_s with BLIP2 (jsonl): 100%|██████████| 747/747 [31:30<00:00,  2.53s/it]

S=1: used 249, skipped 0
q_hat_1 = [0.79919679 0.18875502 0.01204819]
S=2: used 249, skipped 0
q_hat_2 = [0.40160643 0.4939759  0.10441767]
S=3: used 249, skipped 0
q_hat_3 = [0.06024096 0.42971888 0.51004016]





In [10]:
def project_to_simplex(v: np.ndarray) -> np.ndarray:
    """
    Euclidean projection of v onto the probability simplex:
        {x | x >= 0, sum(x) = 1}
    Returns x with same shape as v.
    """
    v = np.asarray(v, dtype=np.float64)
    if v.ndim != 1:
        raise ValueError("project_to_simplex expects a 1D vector")

    # sort v in descending order
    u = np.sort(v)[::-1]
    cssv = np.cumsum(u)
    rho = np.nonzero(u * np.arange(1, len(v) + 1) > (cssv - 1))[0]
    if len(rho) == 0:
        # fallback: all zeros -> uniform
        return np.ones_like(v) / len(v)
    rho = rho[-1]
    theta = (cssv[rho] - 1) / (rho + 1)
    w = np.maximum(v - theta, 0.0)

    # numerical safety
    s = w.sum()
    if s <= 0:
        return np.ones_like(v) / len(v)
    return w / s


In [11]:
# ============================================================
                     # 7. Acc_true 추정량 계산
#    \widehat{Acc_true}(s) = e_s^T (A_hat^T)^{-1} q_hat_s
# ============================================================

# A_hat^T 역행렬 구하기 (필요하다면 regularization 추가 가능)
A_hat = np.array(A_hat, dtype=np.float64)
A_T = A_hat.T

# 혹시 수치적으로 singular 할 경우를 대비한 작은 정규화
eps = 1e-6
A_T_reg = A_T + eps * np.eye(3)

A_T_inv = np.linalg.inv(A_T_reg)

acc_hat_per_s = {}
for s in [1, 2, 3]:
    if s not in q_hat_dict:
        continue
    q_hat_s = q_hat_dict[s]         # shape (3,)
    e_s = np.zeros(3)
    # S=1->full, 2->medium, 3->closeup 이라는 가정
    class_name = S_TO_LABEL[s]
    class_id = label2id[class_name]
    e_s[class_id] = 1.0

    #acc_s = float(e_s @ A_T_inv @ q_hat_s)
    
    p_hat = A_T_inv @ q_hat_s          # unconstrained estimate of p_s
    p_hat = project_to_simplex(p_hat)  # enforce probabilities

    acc_s = float(p_hat[class_id])     # now guaranteed in [0,1]
    acc_hat_per_s[s] = acc_s

    print(f"Estimated Acc_true(s={s}, {class_name}): {acc_s:.4f}  | p_hat={p_hat}")

# 전체 Acc_true = 1/3 * sum_s Acc_true(s)
if len(acc_hat_per_s) > 0:
    overall_acc = sum(acc_hat_per_s.values()) / len(acc_hat_per_s)
else:
    overall_acc = float("nan")

print("\n=== Final Estimated Acc_true (average over s) ===")
print(f"Acc_true_hat = {overall_acc:.4f}")


Estimated Acc_true(s=1, full): 0.8366  | p_hat=[0.83655292 0.16344708 0.        ]
Estimated Acc_true(s=2, medium): 1.0000  | p_hat=[0. 1. 0.]
Estimated Acc_true(s=3, closeup): 0.7749  | p_hat=[0.         0.22511984 0.77488016]

=== Final Estimated Acc_true (average over s) ===
Acc_true_hat = 0.8705


### train_trigger : 0.6088
S=1: used 249, skipped 0
q_hat_1 = [0.78313253 0.19277108 0.02409639]
S=2: used 249, skipped 0
q_hat_2 = [0.12851406 0.29317269 0.57831325]
S=3: used 249, skipped 0
q_hat_3 = [0.         0.00803213 0.99196787]

### train_trigger_cluster : 0.6451
S=1: used 249, skipped 0
q_hat_1 = [0.84738956 0.1124498  0.04016064]
S=2: used 249, skipped 0
q_hat_2 = [0.14859438 0.3373494  0.51405622]
S=3: used 249, skipped 0
q_hat_3 = [0.         0.00803213 0.99196787]

### train_fin : 0.7447
S=1: used 249, skipped 0
q_hat_1 = [0.93574297 0.06425703 0.        ]
S=2: used 249, skipped 0
q_hat_2 = [0.4497992  0.29718876 0.25301205]
S=3: used 249, skipped 0
q_hat_3 = [0.01204819 0.04819277 0.93975904]

### train_base : 0.8058
S=1: used 249, skipped 0
q_hat_1 = [0.82329317 0.16064257 0.01606426]
S=2: used 249, skipped 0
q_hat_2 = [0.43373494 0.35742972 0.20883534]
S=3: used 249, skipped 0
q_hat_3 = [0.04016064 0.0562249  0.90361446]

### train_base2 : 0.8705
S=1: used 249, skipped 0
q_hat_1 = [0.79919679 0.18875502 0.01204819]
S=2: used 249, skipped 0
q_hat_2 = [0.40160643 0.4939759  0.10441767]
S=3: used 249, skipped 0
q_hat_3 = [0.06024096 0.42971888 0.51004016]


In [15]:
train_trigger_mat = [[0.78313253, 0.19277108, 0.02409639],
                    [0.12851406, 0.29317269, 0.57831325],
                    [0.        ,0.00803213, 0.99196787],]
train_trigger_cluster_mat = [[0.84738956, 0.1124498,  0.04016064],
                             [0.14859438, 0.3373494,  0.51405622],
                             [0.        , 0.00803213, 0.99196787]]
train_fin_mat = [[0.93574297, 0.06425703, 0.        ],
                 [0.4497992,  0.29718876, 0.25301205],
                 [0.01204819, 0.04819277, 0.93975904]]
train_base_mat = [[0.82329317, 0.16064257, 0.01606426],
                  [0.43373494, 0.35742972, 0.20883534],
                  [0.04016064, 0.0562249,  0.90361446]]
train_base2_mat = [[0.79919679, 0.18875502, 0.01204819],
                   [0.40160643, 0.4939759,  0.10441767],
                   [0.06024096, 0.42971888, 0.51004016]]

In [16]:
train_trigger_avg = 0.5*(train_trigger_mat[0][0]+train_trigger_mat[2][2])
train_trigger_cluster_avg = 0.5*(train_trigger_cluster_mat[0][0]+train_trigger_cluster_mat[2][2])
train_fin_avg = 0.5*(train_fin_mat[0][0]+train_fin_mat[2][2])
train_base_avg = 0.5*(train_base_mat[0][0]+train_base_mat[2][2])
train_base2_avg = 0.5*(train_base2_mat[0][0]+train_base2_mat[2][2])

print(train_trigger_avg)
print(train_trigger_cluster_avg)
print(train_fin_avg)
print(train_base_avg)
print(train_base2_avg)

0.8875502
0.919678715
0.937751005
0.863453815
0.6546184749999999
