In [106]:
import logging
from logging import handlers
import os
import numpy as np
# import wandb

In [118]:
def pretty(d, indent=0):
    """Print each key/value pair of a mapping on its own line as ``key: value``.

    Args:
        d: Mapping whose items are printed in iteration order.
        indent: Number of leading spaces placed before every line. The default
            of 0 prints unindented lines (previously this argument was ignored).
    """
    prefix = " " * indent
    for key, value in d.items():
        print(f"{prefix}{key}: {value}")

In [119]:
class MetricsLogger:
    """Accumulate per-sample image-quality scores bucketed by gender, age and noise.

    Scores are summed into a nested list indexed as
    ``[gender_idx][age_idx][noise_idx][metric_idx]``. A parallel nested list
    counts samples, so ``compute_metrics_scores`` can return per-bucket means.
    ``wandb_logging`` prints those means under names like ``f_20_noise1_psnr``
    and then resets both accumulators.
    """

    # Number of age buckets. labels['age'] is 1-based (1..NUM_AGE_GROUPS) and is
    # reported as ``age * 10`` in the logged names (e.g. age 2 -> "20").
    NUM_AGE_GROUPS = 6

    def __init__(self, wandb_config):
        """Store the W&B config and allocate zeroed accumulators.

        Args:
            wandb_config: Stored on the instance; no method in this class reads it.
        """
        self.wandb_config = wandb_config

        self.gender_map = {'m': 0, 'f': 1}
        self.metrics_map = {0: 'psnr', 1: 'ssim', 2: 'mssim', 3: 'lpips', 4: 'fid'}
        # noise label 1 (index 0) is the clean image; labels 2..6 are noise1..noise5.
        self.noise_map = {0: 'clean', 1: 'noise1', 2: 'noise2', 3: 'noise3', 4: 'noise4', 5: 'noise5'}

        self.init_metrics_scores()

    def _shape(self):
        """Return the accumulator shape as (genders, ages, noise levels, metrics)."""
        return (len(self.gender_map), self.NUM_AGE_GROUPS, len(self.noise_map), len(self.metrics_map))

    def _zeros(self, fill):
        """Build a fresh nested list of the accumulator shape filled with ``fill``."""
        n_gender, n_age, n_noise, n_metric = self._shape()
        return [[[[fill for _ in range(n_metric)] for _ in range(n_noise)] for _ in range(n_age)]
                for _ in range(n_gender)]

    def _buckets(self):
        """Yield every (gender, age, noise, metric) index tuple in nested-loop order."""
        n_gender, n_age, n_noise, n_metric = self._shape()
        for i in range(n_gender):
            for j in range(n_age):
                for k in range(n_noise):
                    for v in range(n_metric):
                        yield i, j, k, v

    def set_idx(self, labels):
        """Convert one sample's labels into zero-based bucket indices stored on self.

        Sets ``self.gender_idx``, ``self.age_idx`` and ``self.noise_idx``.

        Args:
            labels: dict with 'gender' (a key of ``gender_map``), 'age' (1-based,
                1..NUM_AGE_GROUPS) and 'noise' (1-based, 1..len(noise_map)).

        Raises:
            KeyError: if labels['gender'] is not a key of ``gender_map``.
            ValueError: if age or noise is outside its 1-based range. Without this
                check, a value of 0 would silently index the last bucket via -1.
        """
        _, n_age, n_noise, _ = self._shape()
        gender_idx = self.gender_map[labels['gender']]
        age_idx = labels['age'] - 1
        noise_idx = labels['noise'] - 1
        if not 0 <= age_idx < n_age:
            raise ValueError(f"age must be in 1..{n_age}, got {labels['age']!r}")
        if not 0 <= noise_idx < n_noise:
            raise ValueError(f"noise must be in 1..{n_noise}, got {labels['noise']!r}")
        self.gender_idx = gender_idx
        self.age_idx = age_idx
        self.noise_idx = noise_idx

    def compute_metrics_scores(self):
        """Return per-bucket mean scores; buckets with no samples stay 0.0.

        Returns:
            A new nested list shaped like ``self.metrics_scores``.
        """
        avg_scores = self._zeros(0.0)
        for i, j, k, v in self._buckets():
            count = self.metrics_counter[i][j][k][v]
            if count != 0:
                avg_scores[i][j][k][v] = self.metrics_scores[i][j][k][v] / count
        return avg_scores

    def append_metrics(self, batch_size, labels, metrics_scores):
        """Add the first ``batch_size`` samples' scores to their buckets.

        Args:
            batch_size: Number of samples to consume from ``labels`` / ``metrics_scores``.
            labels: Sequence of per-sample label dicts (see ``set_idx``).
            metrics_scores: Sequence of per-sample score lists, ordered as ``metrics_map``.
        """
        for i in range(batch_size):
            self.set_idx(labels[i])
            score_bucket = self.metrics_scores[self.gender_idx][self.age_idx][self.noise_idx]
            count_bucket = self.metrics_counter[self.gender_idx][self.age_idx][self.noise_idx]
            for idx, score in enumerate(metrics_scores[i]):
                score_bucket[idx] += score
                count_bucket[idx] += 1

    def init_metrics_scores(self):
        """Reset score sums and counters to zeroed lists of the full 4-level shape.

        The previous version wrapped the lists in an extra outer list, which broke
        every call made after the first reset.
        """
        self.metrics_scores = self._zeros(0.0)
        self.metrics_counter = self._zeros(0)

    def wandb_logging(self, epoch):
        """Print the mean of every non-empty bucket, then reset the accumulators.

        Keys look like ``{gender}_{age*10}_{noise}_{metric}``, e.g. ``f_20_noise1_psnr``.

        Args:
            epoch: Currently unused; the ``wandb.log`` call below is commented out.
        """
        avg_scores = self.compute_metrics_scores()
        gender_names = {idx: name for name, idx in self.gender_map.items()}
        wandb_log = {}
        for i, j, k, v in self._buckets():
            if self.metrics_counter[i][j][k][v] != 0:
                name = f"{gender_names[i]}_{(j + 1) * 10}_{self.noise_map[k]}_{self.metrics_map[v]}"
                wandb_log[name] = avg_scores[i][j][k][v]
        pretty(wandb_log)
        # wandb.log(wandb_log)
        self.init_metrics_scores()

In [120]:
import random

# Seed the global RNG so the synthetic labels here (and the scores drawn with
# `random.uniform` in the next cell) are reproducible on Restart & Run All.
SEED = 42
random.seed(SEED)

wandb_config = None
new_logger = MetricsLogger(wandb_config)

# Synthetic labels: 10 female samples with 1-based age/noise buckets in 1..6.
N_SAMPLES = 10
random_age, random_noise = [], []
for _sample in range(N_SAMPLES):
    random_age.append(random.randint(1, 6))
    random_noise.append(random.randint(1, 6))
data = [{'gender': 'f', 'age': age, 'noise': noise} for age, noise in zip(random_age, random_noise)]

for row in data:
    print(row)


{'gender': 'f', 'age': 5, 'noise': 5}
{'gender': 'f', 'age': 3, 'noise': 3}
{'gender': 'f', 'age': 4, 'noise': 2}
{'gender': 'f', 'age': 2, 'noise': 4}
{'gender': 'f', 'age': 3, 'noise': 1}
{'gender': 'f', 'age': 2, 'noise': 5}
{'gender': 'f', 'age': 2, 'noise': 2}
{'gender': 'f', 'age': 3, 'noise': 4}
{'gender': 'f', 'age': 5, 'noise': 3}
{'gender': 'f', 'age': 5, 'noise': 2}


In [121]:
rows, cols = 10, 5

# Fake per-sample metric scores: `rows` samples x `cols` metrics, each uniform in [0, 50],
# drawn row by row.
metrics_scores = [
    [random.uniform(0, 50) for _metric in range(cols)]
    for _sample in range(rows)
]

for row in metrics_scores:
    print(row)

[46.89977925114239, 44.57271738438848, 7.716705752368835, 29.21240533200785, 9.323768189698406]
[29.152282830412407, 12.024466656785549, 48.09339765789273, 27.39663405586436, 26.079211157724725]
[39.34440977135893, 8.509077739538395, 8.404255887173479, 16.203554446316065, 23.906856918352794]
[43.45438877347952, 9.269233949617234, 36.80025195708633, 38.110003600814494, 3.959148129531026]
[21.191750330994985, 48.81254189376028, 6.1223524652488255, 27.930990228314762, 25.741414744895092]
[36.19104522805784, 26.791858713914813, 1.1524965382050667, 41.83750753106294, 8.01387612038727]
[17.645881251504104, 7.0088928139286635, 8.75402168263102, 11.731568964763905, 33.53026165154181]
[3.908747856600442, 49.6373119480884, 10.278073637563494, 24.750433737600886, 14.68299186500031]
[21.187407306809114, 41.04792997002819, 29.436544408043837, 34.133326691378755, 37.08087890781607]
[19.235924913043085, 42.75227375918209, 45.71744828673657, 28.12037297450059, 6.309455494198801]


In [122]:
# Single-line raw view of the nested score list (same text print() would emit).
print(str(metrics_scores))

[[46.89977925114239, 44.57271738438848, 7.716705752368835, 29.21240533200785, 9.323768189698406], [29.152282830412407, 12.024466656785549, 48.09339765789273, 27.39663405586436, 26.079211157724725], [39.34440977135893, 8.509077739538395, 8.404255887173479, 16.203554446316065, 23.906856918352794], [43.45438877347952, 9.269233949617234, 36.80025195708633, 38.110003600814494, 3.959148129531026], [21.191750330994985, 48.81254189376028, 6.1223524652488255, 27.930990228314762, 25.741414744895092], [36.19104522805784, 26.791858713914813, 1.1524965382050667, 41.83750753106294, 8.01387612038727], [17.645881251504104, 7.0088928139286635, 8.75402168263102, 11.731568964763905, 33.53026165154181], [3.908747856600442, 49.6373119480884, 10.278073637563494, 24.750433737600886, 14.68299186500031], [21.187407306809114, 41.04792997002819, 29.436544408043837, 34.133326691378755, 37.08087890781607], [19.235924913043085, 42.75227375918209, 45.71744828673657, 28.12037297450059, 6.309455494198801]]


In [123]:
# Feed every generated sample (rather than a hard-coded count) into the logger,
# then print the per-bucket averages and reset its accumulators.
assert len(data) == len(metrics_scores), "labels and scores must align one-to-one"
new_logger.append_metrics(len(data), data, metrics_scores)
new_logger.wandb_logging(epoch=1)

f_20_noise1_psnr: 17.645881251504104
f_20_noise1_ssim: 7.0088928139286635
f_20_noise1_mssim: 8.75402168263102
f_20_noise1_lpips: 11.731568964763905
f_20_noise1_fid: 33.53026165154181
f_20_noise3_psnr: 43.45438877347952
f_20_noise3_ssim: 9.269233949617234
f_20_noise3_mssim: 36.80025195708633
f_20_noise3_lpips: 38.110003600814494
f_20_noise3_fid: 3.959148129531026
f_20_noise4_psnr: 36.19104522805784
f_20_noise4_ssim: 26.791858713914813
f_20_noise4_mssim: 1.1524965382050667
f_20_noise4_lpips: 41.83750753106294
f_20_noise4_fid: 8.01387612038727
f_30_clean_psnr: 21.191750330994985
f_30_clean_ssim: 48.81254189376028
f_30_clean_mssim: 6.1223524652488255
f_30_clean_lpips: 27.930990228314762
f_30_clean_fid: 25.741414744895092
f_30_noise2_psnr: 29.152282830412407
f_30_noise2_ssim: 12.024466656785549
f_30_noise2_mssim: 48.09339765789273
f_30_noise2_lpips: 27.39663405586436
f_30_noise2_fid: 26.079211157724725
f_30_noise3_psnr: 3.908747856600442
f_30_noise3_ssim: 49.6373119480884
f_30_noise3_mssim: