# Train

# In [None]:
import torch
from sklearn.metrics import f1_score
from util import train_std
from dataset import PhaseIdentification as dataset
from vit import XViT as train_model
import os.path as osp
import os

def metric_func(data, output):
    """Score a batch with macro-averaged F1 on thresholded sigmoid outputs.

    Args:
        data: batch sequence whose last element holds the target labels
            (presumably a 0/1 multi-label tensor — confirm against the dataset).
        output: raw model logits; the sigmoid is applied here.

    Returns:
        sklearn ``f1_score`` with ``average='macro'``; undefined per-label
        scores (zero division) count as 0.
    """
    probabilities = torch.sigmoid(output).detach().cpu().numpy()
    # Binarize at 0.5 so sklearn receives hard 0/1 predictions.
    hard_preds = (probabilities > 0.5).astype(int)
    targets = data[-1].cpu().numpy()
    return f1_score(targets, hard_preds, average='macro', zero_division=0)

# Element-wise loss (no reduction) so loss_func can report per-class values.
criterion = torch.nn.BCEWithLogitsLoss(reduction='none')

def loss_func(data, output):
    """Compute the BCE-with-logits loss for a batch, per class and overall.

    Args:
        data: batch sequence whose last element holds the targets.
        output: raw model logits (assumed shape (batch, num_classes) — confirm).

    Returns:
        ``(per_class_loss, mean_loss)``: the element-wise loss averaged over
        dim 0, and the scalar mean of that vector.
    """
    # BCEWithLogitsLoss needs floating-point targets matching the logits;
    # casting here lets integer (e.g. 0/1 long) labels work instead of raising
    # a dtype error.
    target = data[-1].to(device=output.device, dtype=output.dtype)
    loss = criterion(output, target)
    per_class = torch.mean(loss, 0)
    return per_class, torch.mean(per_class)
    
# --- Experiment configuration ------------------------------------------------
dataname = 'zeolite'
data_config = dict(
    dataname=dataname,
    num=136,
    batch_size=512,
)
model_config = dict(
    stas_fn=f'save/data/{dataname}/stat_signal.pt',
    mlp_drop=0.5,
    img_size=4501,
    in_chans=1,
    patch_size=64,
    num_classes=data_config['num'],
    depth=6,
    drop_rate=0.0,
    num_heads=32,
    embed_dim=256,
)
train_config = dict(
    metric_func=metric_func,
    loss_func=loss_func,
    device='cuda:0',
    epochs=400,
)
optim_para = dict(lr=0.001, weight_decay=0.0)

# --- Output locations --------------------------------------------------------
dataroot = 'save/data/{dataname}/'.format(**data_config)
# The directory name encodes the hyper-parameters, so runs with different
# settings are written to different directories.
modeldir = 'save/model/{dataname}/XViT/mdr{mlp_drop}_dr{drop_rate}_atdp0.0_dep{depth}_ps{patch_size}_nh{num_heads}_hd{embed_dim}_lr{lr}_wd{weight_decay}'.format(
    **data_config, **model_config, **optim_para
)
logdir = modeldir.replace('model', 'log')
os.makedirs(modeldir, exist_ok=True)
os.makedirs(logdir, exist_ok=True)

# --- Select which model files still need training ----------------------------
# Three model files ('a', 'b', 'c'); any whose .npy log already exists is skipped.
modelfns = [osp.join(modeldir, suffix + '.pth') for suffix in 'abc']
logfn = [path.replace('model', 'log').replace('pth', 'npy') for path in modelfns]
modelfns = [model_fn for model_fn, log_fn in zip(modelfns, logfn) if not osp.exists(log_fn)]
print(modelfns)
train_std(
    modelfns, model_config, dataroot, dataset, train_model,
    data_config, train_config, optim_para, torch.optim.AdamW,
)

# Evaluate

# In [None]:
from util import evaluate_std
from dataset import PhaseIdentification as dataset
from vit import XViT as eval_model
import os.path as osp
import torch.nn.functional as F
from sklearn.metrics import f1_score
import torch
import numpy as np

def metric_func(data, output):
    """Score a batch with macro-averaged F1 on thresholded sigmoid outputs.

    Args:
        data: batch sequence whose last element holds the target labels
            (presumably a 0/1 multi-label tensor — confirm against the dataset).
        output: raw model logits; the sigmoid is applied here.

    Returns:
        sklearn ``f1_score`` with ``average='macro'``; undefined per-label
        scores (zero division) count as 0.
    """
    probabilities = torch.sigmoid(output).detach().cpu().numpy()
    # Binarize at 0.5 so sklearn receives hard 0/1 predictions.
    hard_preds = (probabilities > 0.5).astype(int)
    targets = data[-1].cpu().numpy()
    return f1_score(targets, hard_preds, average='macro', zero_division=0)

# --- Configuration ------------------------------------------------------------
# These must match the training cell exactly: modeldir is rebuilt from them to
# locate the saved checkpoints.
dataname = 'zeolite'
num = 136

data_config = dict(dataname=dataname, num=num,
                   batch_size=512)
model_config = dict(stas_fn=f'save/data/{dataname}/stat_signal.pt', mlp_drop=0.5, img_size=4501, in_chans=1, patch_size=64, num_classes=data_config['num'], depth=6, drop_rate=0.0, num_heads=32, embed_dim=256)

optim_para = dict(lr=1e-3, weight_decay=0.0)

dataroot = 'save/data/{dataname}/'.format(**data_config)

modeldir = 'save/model/{dataname}/XViT/mdr{mlp_drop}_dr{drop_rate}_atdp0.0_dep{depth}_ps{patch_size}_nh{num_heads}_hd{embed_dim}_lr{lr}_wd{weight_decay}'.format(**data_config, **model_config, **optim_para)

# NOTE(review): only checkpoint 'a' is evaluated although training produces
# 'a', 'b' and 'c'. If evaluate_std returns one score per checkpoint
# (presumably — confirm), the reported std below is then always 0.
modelfns = [osp.join(modeldir, f'{suffix}.pth') for suffix in ['a']]
# NOTE(review): hard-coded second GPU (training uses 'cuda:0'); this needs at
# least two visible CUDA devices.
device = 'cuda:1'

# --- Evaluation ---------------------------------------------------------------
for mode in ['test']:
    scores = evaluate_std(modelfns, model_config, dataroot, dataset, eval_model, data_config, metric_func, device, mode)
    # Report mean ± std as percentages. The previous literal held mojibake
    # ('Â±', UTF-8 '±' mis-decoded as Latin-1), which printed garbled output.
    print('%s : %.2f±%.2f' % (dataname, np.mean(scores) * 100, np.std(scores) * 100))