Is the training result correct? #13

@chufall

Hi,
Thank you for your great open-source work.
I have written the training code based on the discussion in #4.

The code is as follows:

    # Called from predict()
    def _forward(self, frames_in, frames_gt):

        B, T_in, c, h, w = frames_in.shape
        T_out = frames_gt.shape[1]

        # Deterministic prediction: call the SimVP backbone
        device = frames_in.device
        backbone_output, backbone_loss = self.backbone_net.predict(frames_in, frames_gt=frames_gt,
                                                                   compute_loss=True)

        # Normalize
        frames_in = self.normalize(frames_in)
        frames_gt = self.normalize(frames_gt)
        backbone_output = self.normalize(backbone_output)

        # Compute the residual r = y - mu and the context h
        residual = frames_gt - backbone_output  #eq.7
        global_ctx, local_ctx = self.ctx_net.scan_ctx(torch.cat((frames_in, backbone_output), dim=1))  #eq.14

        # Initialize the state carried across segments
        pre_frag = frames_in
        pre_mu = None
        pred_ress = []
        diff_loss = 0.
        t = torch.randint(0, self.num_timesteps, (B,), device=self.device).long()  # sample one timestep per batch element from [0, num_timesteps)

        # Loop over segments
        for frag_idx in range(T_out // T_in):

            # Take mu and r for the current segment
            mu = backbone_output[:, frag_idx * T_in: (frag_idx + 1) * T_in]   # ^mu_j
            res = residual[:, frag_idx * T_in: (frag_idx + 1) * T_in]         # ^s_j

            # s_{j-1}: at j = 0 there is no s_{-1}, so zeros shaped like frames_in are used instead
            cond = pre_frag - pre_mu if pre_mu is not None else torch.zeros_like(pre_frag)

            # Predict using s_{j-1}, h, and t
            _, noise_loss = self.p_losses(res, t, cond=cond, ctx=global_ctx if frag_idx > 0 else local_ctx,
                                                 idx=torch.full((B,), frag_idx, device=device, dtype=torch.long))
            diff_loss += noise_loss

            pre_frag = frames_gt[:, frag_idx * T_in: (frag_idx + 1) * T_in]
            pre_mu = mu

        diff_loss /= (T_out // T_in)
        loss = (1 - self.loss_weight_factor) * backbone_loss + self.loss_weight_factor * diff_loss

        return loss

    @autocast(enabled = False)
    def q_sample(self, x_start, t, noise = None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, x_start, t, noise=None, offset_noise_strength=None, ctx=None, idx=None, cond=None):
        b, _, c, h, w = x_start.shape

        noise = default(noise, lambda: torch.randn_like(x_start))

        # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise
        offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)

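        if offset_noise_strength > 0.:
            # x_start is 5-D (b, t, c, h, w), so the per-frame offset needs three trailing singleton dims to broadcast
            offset_noise = torch.randn(x_start.shape[:2], device = self.device)
            noise += offset_noise_strength * rearrange(offset_noise, 'b t -> b t 1 1 1')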

        # noise sample
        x = self.q_sample(x_start=x_start, t=t, noise=noise) # Use q_sample here for updating

        model_out = self.model(x, t, cond=cond, ctx=ctx, idx=idx)

        if self.objective == 'pred_noise':
            target = noise
        elif self.objective == 'pred_x0':
            target = x_start
        elif self.objective == 'pred_v':
            v = self.predict_v(x_start, t, noise)
            target = v
        else:
            raise ValueError(f'unknown objective {self.objective}')

        loss = F.mse_loss(model_out, target, reduction = 'none')
        loss = reduce(loss, 'b ... -> b', 'mean')

        loss = loss * extract(self.loss_weight, t, loss.shape)
        return model_out, loss.mean()
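
For reference, here is a minimal standalone sketch of how the segment loop in `_forward` slices the output sequence. `segment_slices` is a hypothetical helper written only for illustration; it is not part of the training code above.

```python
def segment_slices(t_in, t_out):
    """Return the (start, stop) frame indices of each segment.

    Mirrors the slicing `[:, j * T_in : (j + 1) * T_in]` inside the loop
    `for frag_idx in range(T_out // T_in)`.
    """
    return [(j * t_in, (j + 1) * t_in) for j in range(t_out // t_in)]


# The 5 -> 20 setting: four segments of five frames, every frame covered.
print(segment_slices(5, 20))  # [(0, 5), (5, 10), (10, 15), (15, 20)]

# Hypothetical 5 -> 18 setting: floor division leaves frames 15..17 uncovered.
print(segment_slices(5, 18))  # [(0, 5), (5, 10), (10, 15)]
```

With T_in = 5 and T_out = 20 the loop runs four times and covers all output frames. If T_out were not a multiple of T_in, the trailing frames would contribute nothing to the diffusion loss.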

After 200K iterations on the SHANGHAI dataset in the 5:20 setting (5 input frames, 20 predicted frames), I got the following results:

01/05/2025 12:52:33 - INFO - root - ****************************** < Evaluation Results: > ******************************
01/05/2025 12:52:33 - INFO - root - Total 850 samples with 20 seq_len.
01/05/2025 12:52:33 - INFO - root - ******************************************************************************************
01/05/2025 12:52:33 - INFO - root - ====================Threshold: 20 with melthod 1====================
01/05/2025 12:52:33 - INFO - root - <CSI> : 0.565800888469066; [0.82626148 0.77370377 0.73132469 0.6950084  0.65983306 0.63396121
 0.61126816 0.5897937  0.57044672 0.54923247 0.53140962 0.5152367
 0.5000041  0.48699576 0.47042913 0.4575324  0.44525369 0.43425388
 0.42381216 0.41025667]
01/05/2025 12:52:33 - INFO - root - <FAR> : 0.24224349243045568; [0.09140799 0.11713036 0.13861137 0.1577845  0.17509894 0.19173357
 0.20543915 0.21761054 0.23058481 0.2454496  0.26172361 0.27388339
 0.28266903 0.29216861 0.30459843 0.31655573 0.32447319 0.33138028
 0.3382836  0.34828315]
01/05/2025 12:52:33 - INFO - root - <POD> : 0.6789452268045424; [0.90117156 0.86220701 0.82886745 0.79904948 0.76730174 0.74617523
 0.7260127  0.70553094 0.68807796 0.66870393 0.65475824 0.63952266
 0.62269332 0.60951819 0.59252213 0.58056485 0.56638467 0.55334772
 0.54103347 0.52546129]
01/05/2025 12:52:33 - INFO - root - <HSS> : 0.6890801163422913; [0.89509785 0.85943835 0.82916844 0.80209205 0.77483725 0.7540032
 0.7352728  0.71710628 0.70032196 0.68142543 0.66503864 0.65001531
 0.63570107 0.62322894 0.60698766 0.59403597 0.58174211 0.57058332
 0.55987288 0.54563282]
01/05/2025 12:52:33 - INFO - root - < CSI_POOL 4x4 > : 0.6325025462842695; CSI_POOL 16x16: 0.7372783605343171
01/05/2025 12:52:33 - INFO - root - ====================Threshold: 30 with melthod 1====================
01/05/2025 12:52:33 - INFO - root - <CSI> : 0.47518983292786804; [0.77611761 0.70332587 0.65081106 0.60978278 0.57104518 0.54241503
 0.51670373 0.49470957 0.4747346  0.45419573 0.4378177  0.42049936
 0.40426451 0.3901161  0.37451299 0.36072241 0.3478775  0.33526671
 0.32560534 0.31327288]
01/05/2025 12:52:33 - INFO - root - <FAR> : 0.2974831751010489; [0.11926539 0.15896182 0.18961822 0.21324625 0.23454383 0.25044013
 0.26632791 0.27872468 0.29262943 0.30590132 0.31843402 0.33102171
 0.339345   0.34945696 0.36111365 0.37224913 0.38123591 0.38774515
 0.39555952 0.40384346]
01/05/2025 12:52:33 - INFO - root - <POD> : 0.5818251720470691; [0.86726605 0.81115542 0.76772015 0.73052289 0.69215437 0.66247519
 0.63599545 0.61163767 0.59075331 0.56786671 0.5504043  0.53098057
 0.51020965 0.4935439  0.47507953 0.45887849 0.44278149 0.42564184
 0.41377368 0.39766279]
01/05/2025 12:52:33 - INFO - root - <HSS> : 0.6179291139958538; [0.86733634 0.81679291 0.7776323  0.74534312 0.71336996 0.68876311
 0.66590397 0.6457995  0.62703107 0.60721697 0.59098745 0.57350071
 0.5568494  0.54201544 0.52525985 0.51014044 0.49587965 0.48172341
 0.4706891  0.45634757]
01/05/2025 12:52:33 - INFO - root - < CSI_POOL 4x4 > : 0.5486872387090821; CSI_POOL 16x16: 0.650950245431883
01/05/2025 12:52:33 - INFO - root - ====================Threshold: 35 with melthod 1====================
01/05/2025 12:52:33 - INFO - root - <CSI> : 0.39606485937913605; [0.73081963 0.63969667 0.57626237 0.52955149 0.48813026 0.45701573
 0.42922055 0.40402249 0.38486699 0.36413771 0.35018367 0.3348004
 0.32014697 0.30861241 0.29448463 0.2841865  0.27228398 0.26026559
 0.25226034 0.2403488 ]
01/05/2025 12:52:33 - INFO - root - <FAR> : 0.3634073054154245; [0.1444574  0.20044106 0.24029416 0.27191867 0.29757277 0.31559764
 0.33603606 0.35141051 0.36686866 0.38152131 0.39263914 0.40685528
 0.41472505 0.42344394 0.43519667 0.44320855 0.45207288 0.4567756
 0.46447454 0.47263621]
01/05/2025 12:52:33 - INFO - root - <POD> : 0.4973780625681573; [0.83369612 0.76187494 0.70471085 0.66010192 0.61538519 0.57904561
 0.5483362  0.51724965 0.49533172 0.46962838 0.45265733 0.4346078
 0.41408439 0.39906242 0.38091826 0.36726745 0.35117589 0.33318127
 0.32291046 0.30633542]
01/05/2025 12:52:33 - INFO - root - <HSS> : 0.5443081733634114; [0.83988494 0.77385992 0.72345734 0.68371708 0.64648701 0.61720833
 0.58996796 0.56440731 0.54435543 0.52208137 0.50671081 0.48939319
 0.4726161  0.45916685 0.44233217 0.42988669 0.41527263 0.40032962
 0.39016709 0.37486163]
01/05/2025 12:52:33 - INFO - root - < CSI_POOL 4x4 > : 0.48441086275194206; CSI_POOL 16x16: 0.5865675293493181
01/05/2025 12:52:33 - INFO - root - ====================Threshold: 40 with melthod 1====================
01/05/2025 12:52:33 - INFO - root - <CSI> : 0.285276937459521; [0.6645734  0.54933477 0.47211479 0.41623437 0.37060638 0.33353102
 0.30549998 0.27927691 0.25895973 0.24064859 0.22956794 0.21639686
 0.20275106 0.1936537  0.18488311 0.1772157  0.16702052 0.15529105
 0.14945041 0.13852845]
01/05/2025 12:52:33 - INFO - root - <FAR> : 0.4689316897550597; [0.18273019 0.26169754 0.31643544 0.36079712 0.39470287 0.41781393
 0.44302741 0.46294865 0.48355582 0.49967694 0.51101887 0.52326946
 0.53594818 0.54682181 0.55620708 0.56156442 0.56961588 0.57769891
 0.58252461 0.59057862]
01/05/2025 12:52:33 - INFO - root - <POD> : 0.36574409829472676; [0.7805557  0.68216308 0.6041532  0.54405659 0.48871093 0.43848919
 0.40356718 0.36782879 0.34184685 0.31678369 0.30202902 0.28380725
 0.26474463 0.2527027  0.24064379 0.22925242 0.21441824 0.19717867
 0.18882775 0.17312233]
01/05/2025 12:52:33 - INFO - root - <HSS> : 0.4211213455260009; [0.7956374  0.70508913 0.63656133 0.58233788 0.53489373 0.49405709
 0.461629   0.43007983 0.40470491 0.38120264 0.3666254  0.34901006
 0.33035745 0.3176708  0.30529602 0.29437251 0.27963301 0.26240511
 0.2536806  0.23718301]
01/05/2025 12:52:33 - INFO - root - < CSI_POOL 4x4 > : 0.40419628442688127; CSI_POOL 16x16: 0.51684827582216
01/05/2025 12:52:33 - INFO - root - ********************Overall Avg Metrics on Thresholds [20, 30, 35, 40]********************
01/05/2025 12:52:33 - INFO - root - [ avg_csi ] : 0.43058312955889777; [ avg_far ] : 0.3430164156754972; [ avg_pod ] : 0.5309731399286239; [ avg_hss] : 0.5681096873068894
01/05/2025 12:52:33 - INFO - root - [ avg_csi_pool 4x4 ] : 0.5174492330430437; [ avg_csi_pool 16x16 ]: 0.6229111027844196
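
As a rough sanity check on the log above, the sketch below uses the standard contingency-table definitions of CSI, FAR, POD and HSS. It assumes the evaluation script follows those definitions (its code is not shown here), and `scores` and `csi_from_pod_far` are hypothetical helper names. It checks that the logged CSI is consistent with the logged POD and FAR, and that `avg_csi` is the plain mean of the four per-threshold CSI values.

```python
def scores(tp, fp, fn, tn):
    """Categorical scores from a 2x2 contingency table (standard definitions)."""
    pod = tp / (tp + fn)                      # probability of detection
    far = fp / (tp + fp)                      # false alarm ratio
    csi = tp / (tp + fp + fn)                 # critical success index
    hss = 2 * (tp * tn - fn * fp) / (
        (tp + fn) * (fn + tn) + (tp + fp) * (fp + tn)
    )                                         # Heidke skill score
    return {"csi": csi, "far": far, "pod": pod, "hss": hss}


def csi_from_pod_far(pod, far):
    """CSI implied by POD and FAR: 1/CSI = 1/POD + 1/(1 - FAR) - 1."""
    return 1.0 / (1.0 / pod + 1.0 / (1.0 - far) - 1.0)


# First lead time at threshold 20, copied from the log.
pod_1, far_1, csi_1 = 0.90117156, 0.09140799, 0.82626148
print(abs(csi_from_pod_far(pod_1, far_1) - csi_1) < 1e-5)

# Per-threshold CSI at thresholds 20, 30, 35, 40, copied from the log.
csi_per_threshold = [
    0.565800888469066,
    0.47518983292786804,
    0.39606485937913605,
    0.285276937459521,
]
avg_csi = sum(csi_per_threshold) / len(csi_per_threshold)
print(abs(avg_csi - 0.43058312955889777) < 1e-9)
```

Both checks pass under these assumptions, so the logged numbers are at least internally consistent. That says nothing about whether the model itself was trained correctly.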

The sample output frames are as follows:
image

Do these results make sense? Could you please take a look and give me some suggestions?
Thanks a lot!

Sincerely,
QC
