In [1]:
import numpy as np
import json
from easydict import EasyDict as edict
import yaml
from pathlib import Path
import os
import argparse
import random
from tqdm.notebook import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.optim import lr_scheduler
from sklearn.model_selection import train_test_split

from utils.utils import seed_everything, create_dataset_from_json
from dataset.lux_dataset import LuxDataset
from models.lux_net import LuxNet
from models.lux_unet import LuxUNet
from tools.train import train_model

In [2]:
# Directory handed to create_dataset_from_json. Override it with the
# LUX_DATA_DIR environment variable instead of editing a machine-specific
# absolute path; the default is the original location.
DATA_DIR = os.environ.get('LUX_DATA_DIR', 'E:/Datasets/LUX_ai/DATA/top/')
# Keep a trailing separator, as the original literal had one.
# NOTE(review): presumably create_dataset_from_json may concatenate file names
# onto this string — TODO confirm and drop this if it uses os.path.join.
DATA_DIR = os.path.join(DATA_DIR, '')
obses, samples = create_dataset_from_json(DATA_DIR)

  2%|▏         | 9/467 [00:00<00:05, 77.54it/s]

Loading data...


100%|██████████| 467/467 [00:11<00:00, 41.12it/s]


In [13]:
# Histogram of labels over all samples: cnt[k] is the number of samples
# whose label is k. Assumes labels are ints in 0..4; a larger label raises
# IndexError.
cnt = [0] * 5
for obs_id, unit_id, label in tqdm(samples):
    cnt[label] = cnt[label] + 1

  0%|          | 0/988664 [00:00<?, ?it/s]

In [14]:
# Negate each class count; the following weighting cell reads the negated `cnt`.
# NOTE: `cnt` is rebound from itself, so re-running this cell flips the sign
# back — run it exactly once after the counting cell.
cnt = [-c for c in cnt]
print(cnt)

[-265925, -263700, -202506, -197183, -59350]


In [15]:
# Turn the negated class counts into relative class weights: scale by 1e5,
# map through 2**x, and normalise. Less frequent classes (less negative
# values) end up with larger weights.
# NOTE: `cnt` is rebound here (kept for any later cells that read it), so
# re-running this cell rescales it a second time.
cnt = np.array(cnt) / 100000
print(cnt)
cnt_exp = np.exp2(cnt)
# Keep the normaliser under its own name: binding it to `sum` shadowed the
# Python builtin for every subsequent cell in the kernel.
cnt_exp_total = np.sum(cnt_exp)
print(cnt_exp_total)
cnt_exp = cnt_exp / cnt_exp_total * 10   # weights now sum to 10
cnt_exp / cnt_exp.min()   # weights relative to the smallest one

[-2.65925 -2.637   -2.02506 -1.97183 -0.5935 ]
1.4824215775385976


array([1.        , 1.01554207, 1.5520661 , 1.61040103, 4.1865156 ])

In [6]:
# Histogram of samples over observation step: obs_cnt[s] counts the samples
# whose observation has obs['step'] == s. There are 360 slots, so steps are
# assumed to be ints in 0..359.
obs_cnt = [0] * 360
for obs_id, unit_id, label in tqdm(samples):
    obs = obses[obs_id]
    step = obs['step']
    obs_cnt[step] = obs_cnt[step] + 1
# Each sample increments exactly one slot, so this equals the sample count.
np.sum(obs_cnt)

  0%|          | 0/988664 [00:00<?, ?it/s]

988664

In [7]:
# Scale the per-step counts so their mean is 1, then negate. Steps with many
# samples get the most negative values, which the next cell feeds to exp2.
# The sign is applied after the division so empty steps show up as -0.
obs_cnt = -(np.asarray(obs_cnt) / np.mean(obs_cnt))
obs_cnt

array([-0.16240098, -0.        , -0.14310221, -0.00436953, -0.13800442,
       -0.1390968 , -0.15621081, -0.13290663, -0.14091744, -0.1325425 ,
       -0.18388451, -0.22357444, -0.23704717, -0.25234053, -0.25124815,
       -0.26617739, -0.28948156, -0.34009532, -0.3542963 , -0.40563832,
       -0.39216559, -0.44532824, -0.44205109, -0.50868647, -0.51159949,
       -0.57350121, -0.59498475, -0.65397344, -0.70349482, -0.74536951,
       -0.61719654, -0.70968499, -0.41947517, -0.30113365, -0.47809974,
       -0.5826044 , -0.43003488, -0.31278574, -0.29931301, -0.17259655,
       -0.74682602, -0.63248991, -1.17212723, -0.50759409, -1.01992183,
       -0.60954986, -0.94891692, -0.843684  , -1.01191102, -0.85861324,
       -1.06689431, -0.91614542, -1.13025254, -1.08364419, -1.21036065,
       -1.2070835 , -1.24895819, -1.26862109, -1.32433263, -1.37057686,
       -1.38878325, -1.43721224, -1.43830462, -1.5114943 , -1.50275523,
       -1.57740142, -1.53661911, -1.61199356, -1.59597194, -1.61

In [8]:
# Per-step sample weights: 2**(negated normalised count), rescaled to mean 1,
# so steps with many samples are down-weighted. exp2 is evaluated once. The
# result is bound to a name so later cells need not rely on `_` / Out[8].
obs_exp = np.exp2(obs_cnt)
step_weights = obs_exp / np.mean(obs_exp)
step_weights

array([1.65561897, 1.85288284, 1.67791479, 1.84727945, 1.68385422,
       1.68257972, 1.66273799, 1.68981467, 1.68045769, 1.69024122,
       1.63114738, 1.5868846 , 1.57213433, 1.55555686, 1.55673514,
       1.54070886, 1.51602138, 1.46375727, 1.44941965, 1.39874536,
       1.41186884, 1.36078903, 1.36388364, 1.30232109, 1.29969415,
       1.24510762, 1.22670378, 1.17755808, 1.13782355, 1.10527254,
       1.20796204, 1.13295195, 1.38539415, 1.50382637, 1.3302265 ,
       1.23727591, 1.37529088, 1.49172947, 1.50572536, 1.6439599 ,
       1.10415724, 1.19522461, 0.82224739, 1.30330756, 0.91373633,
       1.21438157, 0.95983261, 1.03246127, 0.91882412, 1.02183228,
       0.88446524, 0.98188518, 0.84646312, 0.87425586, 0.8007429 ,
       0.8025639 , 0.77960404, 0.76905068, 0.73991888, 0.71657757,
       0.7075914 , 0.68423292, 0.68371503, 0.64989446, 0.65384312,
       0.62087306, 0.63867441, 0.60616317, 0.61293234, 0.60555151,
       0.78394496, 0.77960404, 1.06125958, 1.2316676 , 1.03376