In [1]:
%%capture
import pandas as pd
import numpy as np
import os
import yaml
from easydict import EasyDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.metrics import normalized_mutual_info_score as NMI

from z_common import config_to_execute
from clustering import run_clustering
from dataset import get_datasets
from simclr import SimCLR
from log_functions import *

from models.baseline_encoder import Encoder
from models.alexnet_simclr import AlexSimCLR
from models.resnet_simclr import ResNetSimCLR
from loss.nt_xent import NTXentLoss
from functions import *

In [2]:
import pandas as pd
import numpy as np
import io
from glob import glob
pd.set_option('display.max_rows', 100)

# titles = ['dslr_webcam', 'amazon_dslr', 'webcam_amazon']
# file_domains = ['dw', 'ad', 'wa']
# log_files = ['../log_dw.txt', '../log_ad.txt', '../log_wa.txt']

# EXEC_REP_NUM: executions per augmentation setting.
# CLUST_REP_NUM: clustering runs per execution.
# Their product is the number of result rows expected per augmentation setting.
EXEC_REP_NUM = 2
CLUST_REP_NUM = 3
title = 'svhn_synth_mnist_mnistm'
file_domain = 'svhn_synth_mnist_mnistm'
log_file = f'./record/Digit/{title}_logs.txt'

# prompts.log paths, one inner list per CUDA directory (CUDA0 .. CUDA9).
# glob() returns matches in a filesystem-dependent order, while the results
# are later labelled purely by position, so each list is sorted to make the
# order reproducible across machines and re-runs.
# NOTE(review): confirm that the sorted directory order matches the order
# expected by the positional labelling in the analysis cell.
prompts_files_t = [
    sorted(glob(f"./record/Digit/CUDA{i}/svhn_synth_mnist_mnistm_*/prompts.log"))
    for i in range(10)
]

# Keep only the CUDA directories that actually contain a prompts.log,
# giving a 2-D list indexed as [cuda_dir][file].
prompts_files = [cuda_files for cuda_files in prompts_files_t if len(cuda_files) != 0]


# Merge the prompts.log of every run directory into a single log file.
# Number of leading lines of a prompts.log copied verbatim as the header of
# each CUDA-directory section in the merged log.
HEADER_LINE_NUM = 14
# A line is kept in the merged log when it contains any of these markers.
LOG_KEYWORDS = ('Epoch:', 'nmi:', 'nmi class:', 'domain_accuracy:')
# Separator written in front of the lines taken from each prompts.log.
RUN_SEPARATOR = '\n==========================================\n'

if not prompts_files:
    # Previously this case only surfaced as an opaque
    # "need at least one array to concatenate" error from numpy.
    raise FileNotFoundError(
        f"No prompts.log files were found, so {log_file} cannot be built."
    )

merged_lines = []
for cuda_prompts_files in prompts_files:
    header_lines = []
    run_lines = []
    for prompts in cuda_prompts_files:
        with open(prompts, 'r') as f:
            # Every line is re-terminated with '\n' (the text after the final
            # newline becomes a lone '\n' line, as before).
            lines = [f"{line}\n" for line in f.read().split('\n')]

        # The header of the last file read in this directory is the one that
        # ends up in the output (same as the original behaviour).
        header_lines = lines[:HEADER_LINE_NUM]

        run_lines.append(RUN_SEPARATOR)
        run_lines.extend(
            line for line in lines
            if any(keyword in line for keyword in LOG_KEYWORDS)
        )

    # One section per CUDA directory: header first, then all of its runs.
    merged_lines.extend(header_lines)
    merged_lines.extend(run_lines)

# Write the merged log file.
output_texts = np.array(merged_lines)
with open(log_file, 'w', newline='\n') as f:
    f.writelines(merged_lines)


In [3]:
"""
    dft: ログから得た値(縦持ち)
    df: ログから得た値(横持ち). 各実行,各クラスタリングそれぞれの値を全て保持.
            レコード数: len(aug_ilst) * EXEC_REP_NUM * CLUST_REP_NUM
    dfg: クラスタリングの平均値をまとめた.
            レコード数: len(aug_ilst) * EXEC_REP_NUM
    dfg_avg: その平均値をまとめた
            レコード数: len(aug_ilst)
    csvに書き込んでいく.
"""
# Result containers; nothing in the code visible here fills them.
dfs = {}
dfgs = {}
dfg_avgs = {}


# Read the merged log, keep only the metric lines (nmi / nmi class /
# domain_accuracy) and load them into a long-format DataFrame `dft`
# with one row per metric line: (title, result).
with open(log_file, 'r') as f:
    text = f.read()
split_text = text.split('\n')
log_text = [line for line in split_text if 'nmi:' in line or 'nmi class:' in line or 'domain_accuracy:' in line]

# Split on the first ':' only, so a value that itself contains a colon still
# yields exactly two fields instead of a ragged row. The reshape keeps the
# array two-column even when no metric line was found.
split_logs = np.array([line.split(':', 1) for line in log_text], dtype=str).reshape(-1, 2)

dft = pd.DataFrame(split_logs, columns=['title', 'result'])

# Augmentation settings. Metric values are attached to these labels purely
# by position, so the list must follow the order of the results in the log.
augs_list = [
    'none',
    'jigsaw_const_phase',
    'jigsaw',
    # 'mask',
    # 'jigsaw_mask',
    # 'mask_const_phase',
    'jigsaw_mask_const_phase',
]

# Frames built below:
#   df      : one row per (augmentation, execution, clustering run)
#             -> len(augs_list) * EXEC_REP_NUM * CLUST_REP_NUM rows
#   dfg     : mean over the clustering runs of each execution
#             -> len(augs_list) * EXEC_REP_NUM rows
#   dfg_avg : mean of dfg over the executions
#             -> len(augs_list) rows
runs_per_aug = EXEC_REP_NUM * CLUST_REP_NUM

augs_rep = [aug for aug in augs_list for _ in range(runs_per_aug)]
df = pd.DataFrame(augs_rep, columns=['augs'])

# Each execution contributes CLUST_REP_NUM consecutive rows, so the execution
# index is the row offset divided by CLUST_REP_NUM. Dividing by EXEC_REP_NUM
# (as before) produced CLUST_REP_NUM groups of EXEC_REP_NUM rows, which mixed
# executions and gave dfg the wrong number of rows.
df['exec_number'] = [i // CLUST_REP_NUM for i in range(runs_per_aug)] * len(augs_list)

# Output column -> title used for that metric in the log.
metric_titles = {
    'nmi_domain': 'nmi',
    'nmi_class': 'nmi class',
    'domain_accuracy': 'domain_accuracy',
}
for column, log_title in metric_titles.items():
    values = dft[dft['title'] == log_title].result.values
    if len(values) != len(df):
        # Fail with a readable message instead of pandas' generic length error.
        raise ValueError(
            f"Expected {len(df)} '{log_title}' entries in the log "
            f"({len(augs_list)} augs x {EXEC_REP_NUM} executions x "
            f"{CLUST_REP_NUM} clusterings) but found {len(values)}."
        )
    df[column] = values

metric_columns = list(metric_titles)
df[metric_columns] = df[metric_columns].astype('float').round(5)

dfg = df.groupby(['augs', 'exec_number']).mean().reset_index()
dfg_avg = dfg.drop('exec_number', axis=1).groupby('augs').mean().reset_index()


# Write the three result tables to CSV.
df.to_csv(f'./record/Digit/{title}_result_df.csv', header=True, index=False)
dfg.to_csv(f'./record/Digit/{title}_result_dfg.csv', header=True, index=False)
dfg_avg.to_csv(f'./record/Digit/{title}_result_dfg_avg.csv', header=True, index=False)