# 0. Intro

# 1. Preparation

In [1]:
from common import *
from utils import *
from mydataset import *
from my_collate_fn import my_collate_fn_3

from Config.config_MDN import DefaultConfig

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


# 2. 执行validation
- 注意使用的validation function

In [9]:
MODEL_LIST = ["GT1(MDN)","GT2(MDN)","GT3","EMD"]
class Conv_block_4(nn.Module):

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def __init__(self, ch_out=1,kernel_size=9, stride=3, init_weight=True) -> None:
        super().__init__()

        self.kernel_size = (1,kernel_size)
        self.stride = (1,stride)
        self.ln_in = int((300-kernel_size)/stride+1)

        self.ac_func = nn.Softplus()

        self.conv = nn.Conv2d(in_channels=1, out_channels=ch_out, kernel_size=self.kernel_size, stride=self.stride, padding=0, dilation=(1,1))

        self.BN_aff2 = nn.BatchNorm1d(num_features=self.ln_in,affine=True)      # works better

        if init_weight:
            self._initialize_weights()

    def forward(self, x):
        # Conv=>BN=>AC
        x = self.conv(x)
        # print(x.shape)
        # 方法一：
        # x = torch.squeeze(x,dim=2)
        # x = self.ac_func(self.BN_aff1(x))

        # 方法二：works better
        x = torch.flatten(x,start_dim=1)
        x = self.ac_func(self.BN_aff2(x))
        # x = self.ac_func(self.BN_aff1(x))
        return x

class Conv_1_4(nn.Module):
    # code->generate->override methods
    def __init__(self, n_gaussians, ch_out=1, kernel_size=9, stride=3) -> None:
        super().__init__()

        self.ln_in = int((300-kernel_size)/stride+1)

        self.BN1 = nn.BatchNorm1d(num_features=1,affine=True)
        # self.IN1 = nn.InstanceNorm1d(num_features=3,affine=True)
        self.layer_pi = Conv_block_4(ch_out=1,kernel_size=kernel_size,stride=stride)
        self.layer_scale = Conv_block_4(ch_out=1,kernel_size=kernel_size,stride=stride)
        self.layer_shape = Conv_block_4(ch_out=1,kernel_size=kernel_size,stride=stride)

        self.ac_func = nn.Softplus()

        self.z_pi = nn.Sequential(
            nn.Linear(self.ln_in, n_gaussians),
            nn.Softmax(dim=1)           # dim=0是B, dim=1才是feature
        )
        self.z_scale = nn.Linear(self.ln_in, n_gaussians)
        self.z_shape = nn.Linear(self.ln_in, n_gaussians)

    def forward(self, x):

        x = self.BN1(x)
        # x = self.IN1(x)
        x = torch.unsqueeze(x,dim=1)                     # torch.Size([B, 1, 3, 300])

        x_pi = self.layer_pi(x)
        x_scale = self.layer_scale(x)
        x_shape = self.layer_shape(x)

        pi = self.z_pi(x_pi)
        scale = torch.exp(self.z_scale(x_scale))
        scale = torch.clamp(scale,1e-4)
        shape = torch.exp(self.z_shape(x_shape))
        shape = torch.clamp(shape,1e-4)

        return pi,scale,shape


In [10]:
seed = 62
setup_seed(seed)

In [15]:
for MODEL_NAME in MODEL_LIST:
    opt = DefaultConfig(MODEL_NAME=MODEL_NAME)
    mlp = Conv_1_4(opt.N_gaussians).to(device)

    model_path = get_MDN_save_path(opt.ARTIFICIAL, seed, opt.net_root_path, opt.noise_pct, MODEL_NAME)
    mlp, hyperparameters = load_checkpoint(model_path, mlp)

    dataset = myDataset(opt.train_path, opt.target_path_metric, opt.target_path_loss, opt.data_key_path, opt.NLL_metric_path)
    shuffled_indices = save_data_idx(dataset, opt.arr_flag)
    train_idx, val_idx, test_idx = get_data_idx(shuffled_indices, opt.train_pct, opt.vali_pct,opt.SET_VAL)

    if MODEL_NAME == "GT1(MDN)":
        INPUT_LIST = [1]
    elif MODEL_NAME == "GT2(MDN)":
        INPUT_LIST=[2]
    elif MODEL_NAME == "GT3":
        INPUT_LIST=[3]
    elif MODEL_NAME == "EMD":
        INPUT_LIST=[4]
    else:
        assert f"Wrong Model Name! The Name Has to be One of {MODEL_LIST}"

    my_collate_fn = functools.partial(my_collate_fn_3, INPUT_LIST=INPUT_LIST)
    _,_,test_loader = get_data_loader(dataset, opt.batch_size, train_idx, val_idx, test_idx, my_collate_fn)

    with torch.no_grad():
        total_test_metric, GT_metric = validate_KL(mlp, test_loader, opt.N_gaussians, opt.MIN_LOSS, device)

    print(f"========== seed = {seed} ==========")
    print(f"========== MODEL_NAME = {MODEL_NAME} ==========")

    print(f"========== total_test_metric: {total_test_metric} ==========")
    print(f"========== GT_metric: {GT_metric} ==========")


ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 98])

In [14]:
for epoch in range(0):
    print(f"aaa")