In [1]:
# Data Loader
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
import pytorch_lightning as pl
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.data import random_split
import json
from CustomDataset import * 
import sys
sys.path.append('../')
from config import *
from file_helper import *

def collate_fn(batch):
    return tuple(zip(*batch))
#=================================
#             Augmentation
#=================================

def gauss_noise_tensor(img):
    rand = torch.rand(1)[0]
    if rand < 0.5 and Horizon_AUG:
        sigma = rand *0.125
        out = img + sigma * torch.randn_like(img)
        return out
    return img

def blank(img):    
    return img

class CustomDataModule(pl.LightningDataModule):
    def __init__(self , train_dir , test_dir , batch_size = 2, num_workers = 0 , img_size=[IMG_WIDTH, IMG_HEIGHT] , use_aug = True ,padding_count = 24 ,c =0.1 ):
        super().__init__()
        self.train_dir = train_dir
        self.test_dir = test_dir
        
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.img_size = img_size      
        self.use_aug = use_aug
        self.padding_count  = padding_count
        self.c = c
        
        pass

    def prepare_data(self) -> None:
        # Download dataset
        pass

    def setup(self, stage):
        # Create dataset...          
                
        self.entire_dataset = CustomDataset(self.train_dir  , use_aug= self.use_aug , padding_count= self.padding_count , c=self.c)
        self.train_ds , self.val_ds = random_split(self.entire_dataset , [0.9, 0.1])        
        self.test_ds = CustomDataset(self.test_dir  , use_aug= False)
        
        pass

    # ToDo: Reture Dataloader...
    def train_dataloader(self) -> TRAIN_DATALOADERS:
        return DataLoader(self.train_ds , batch_size= self.batch_size , num_workers= self.num_workers , shuffle=True)
    
    def val_dataloader(self) -> EVAL_DATALOADERS:
        return DataLoader(self.val_ds , batch_size= self.batch_size , num_workers= self.num_workers , shuffle=False)
    
    def test_dataloader(self) -> EVAL_DATALOADERS:
        return DataLoader(self.test_ds , batch_size= self.batch_size , num_workers= self.num_workers , shuffle=False)

    pass


# Test
dm = CustomDataModule ( train_dir= f"../anno/test_visiable_10_no_cross.json" ,
                       test_dir= f"../anno/test_visiable_10_no_cross.json" , padding_count=256
                       )

  from .autonotebook import tqdm as notebook_tqdm


d:\Projects\Layout\NTHU_CGV_Layout_exp\ours
d:\Projects\Layout\NTHU_CGV_Layout_exp\Horizon_and_SAM\Horizon


In [2]:
from torch import Tensor
def unpad_data( x :[Tensor] ) :
	non_zero_indices = torch.nonzero(x)
	#print(non_zero_indices)
	# Get the non-zero values
	non_zero_values = x[non_zero_indices[:,0], non_zero_indices[:,1]]

	unique = torch.unique(non_zero_indices[:,0] ,return_counts=True)
	#print("unique" , unique)
	# Print the result
	#print(non_zero_values)
	non_zero_values = torch.split(non_zero_values , tuple(unique[1]))
	#print("split non_zero_values" , non_zero_values)
	return non_zero_values

In [12]:
import torch
from torch import nn
from torch.nn import functional as F
from typing import Any
import pytorch_lightning as pl
from config import *
import torchvision.models as models
from torchvision.ops import MLP
import math
from torch import Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import matplotlib.pyplot as plt
from scipy.optimize import linear_sum_assignment
from VerticalCompressionNet import * 

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: 256, dropout: float = 0.0, max_len: int = 1024):
        super().__init__()        
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)        
        print("position", position.shape)
        div_term = torch.exp(torch.arange(0, d_model) * (-math.log(10000.0) / d_model))
        #div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        print("div_term", div_term.shape)
        #pe = torch.zeros(max_len, 1, d_model)  # []
        pe = torch.zeros(max_len, d_model)  # [ 1024 , 256 ]
        print("pe", pe.shape)
        #pe[:, 0, 0::2] = torch.sin(position * div_term)
        #pe[:, 0, 1::2] = torch.cos(position * div_term)
        pe[: ,:] = torch.sin(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:    
        """
        Arguments:
            #x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
            x: Tensor, shape ``[batch_size , seq_len , embedding_dim]``
        """
        #x = x + self.pe[:x.size(0)]
        # [batch size , 1024 , 256 ]
        x = x + self.pe
        return self.dropout(x) , self.pe
        #return self.dropout(x)
        #return self.pe[:0]
    
class TransformerModel(nn.Module):
    def __init__(self, ntoken: int, d_model: int, nhead: int , d_hid: int, nlayers: int, dropout: float = 0.1):
        super().__init__()
        self.ntoken = ntoken
        self.model_type = 'Transformer'
        self.d_model = d_model
        
        self.enc_embedding = nn.Embedding(num_embeddings=ntoken ,embedding_dim= d_model)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model , nhead , d_hid , dropout) 
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=6)
        '''
        self.dec_embedding = nn.Embedding(num_embeddings=ntoken ,embedding_dim= d_model)
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead , dropout=dropout)
        self.decoder = nn.TransformerDecoder( self.decoder_layer  , num_layers=6)
        ''' 

        '''
        #self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.d_model = d_model
        self.linear = nn.Linear(d_model, ntoken)
        self.init_weights()
        '''

    def init_weights(self) -> None:
        initrange = 0.1
        #self.embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
        #src = self.embedding(src) * math.sqrt(self.d_model)
        #print("trans forward " , src.shape)
        #src = self.pos_encoder(src)
        #output = self.transformer_encoder(src, src_mask)
        #output = self.linear(output)
        #return output
        #print("self.ntoken" , self.ntoken)
        #print("self.src" , src.shape)
        ''''''
        pos_idx = torch.arange(self.ntoken , device=src.device)

        pos_src = self.enc_embedding(pos_idx)* math.sqrt(self.d_model) + src
        pos_enc_src = self.encoder(pos_src )

        pos_dec_src = self.dec_embedding(pos_idx)* math.sqrt(self.d_model) + pos_enc_src
        #pos_dec_src = self.dec_embedding( torch.tensor([0,1],device=src.device) )* math.sqrt(self.d_model) + pos_enc_src
        dec_src = self.decoder( pos_enc_src  ,pos_dec_src)

        #print("dec_src" , dec_src.shape)    # [batch , token , hidden]
        return dec_src


    

class VerticalQueryTransformer(pl.LightningModule):    
    def __init__(self  ,  max_predict_count = 24 , hidden_out = 128 , class_num = 1 , log_folder = "__test" , num_classes = 1 , backbone_trainable =False, load_weight ="" , top_k = 20):
        #print(" input_size" ,  input_size)
        super().__init__()
        self.backbone = Resnet()
        self.out_scale = 8
        self.step_cols = 4        
        self.hidden_size = hidden_out
        self.max_predict_count = max_predict_count
        self.num_classes  = num_classes 
        self.top_k_num = top_k        

        self.fixed_pe = PositionalEncoding(hidden_out, 0.1 , 1024)

        #self.transformer = TransformerModel( ntoken= max_predict_count , d_model=hidden_out , nhead=8 , d_hid= 2048,nlayers=6 )
        self.transformer = nn.TransformerEncoderLayer(hidden_out ,  8 , 2048 , dropout= 0.1) 

        #self.box_head= nn.Linear( hidden_out , 6 )        
        self.vqt_box_head= nn.Linear( hidden_out , 5 )        
        self.vqt_cls_head= nn.Linear( hidden_out , class_num )        
        self.confidence_threshold = 0.5

        # loss
        self.box_cost = 1
        self.cls_cost = 20

        self.log_folder = create_folder(os.path.join(os.getcwd() , "output" , log_folder))
        #self.box_head.bias.data = torch.nn.Parameter(torch.tensor([0.3,0.2,0.2,0.3]))
        #self.box_head.weight.data.fill_(0)
        
        # Inference channels number from each block of the encoder
        with torch.no_grad():
            dummy = torch.zeros(1, 3, 320, 190)
            c1, c2, c3, c4 = [b.shape[1] for b in self.backbone(dummy)] # resnet feature channel數
            #print("c1, c2, c3, c4" , c1, c2, c3, c4)
            c_last = (c1*8 + c2*4 + c3*2 + c4*1) // self.out_scale            
        self.v_reproj = nn.Conv2d(1024 , self.max_predict_count,kernel_size=1)
        self.reduce_height_module = GlobalHeightStage(c1, c2, c3, c4 , out_scale=self.out_scale , pretrain_weight= load_weight , freeze_model= not backbone_trainable)


    def forward(self ,x ):
        features = self.backbone(x) # [4 , c , h, w]      
        # vertical feature
        reduced_feats  = self.reduce_height_module(features , x.shape[3]//self.step_cols ) # [b , 1024 ,  256] width = 1024 , 256d latent code each.
        '''
        plt.imshow(reduced_feats[0].detach().cpu().numpy())
        plt.title("vertical feature")
        plt.show()        
        print("reduced_feats" , reduced_feats.shape)  # [b , 1024 , 256]
        '''
        
        # Add fixed PE
        embedded_feat , pe_pattern = self.fixed_pe(reduced_feats)
        '''
        print("pe_pattern " , pe_pattern.shape)
        plt.imshow(pe_pattern.detach().cpu().numpy())
        plt.title("PE Pattern")
        plt.show()

        plt.imshow(embedded_feat[0].detach().cpu().numpy())
        plt.title("PE embedded_feat")
        plt.show()
        '''

        #reduced_feats = self.v_reproj(reduced_feats.view(reduced_feats.shape[0] , 1024 , 256,-1)).view(reduced_feats.shape[0]  , 256 , 256)
        #print("reduced_feats" , reduced_feats.shape)
        output = self.transformer(embedded_feat)  # (b , 1024 , 256 )
        #print("encoder out " , output.shape)
        
        batch_size = output.shape[0]
        #output = output.view( batch_size , self.max_predict_count , -1 ) # (b , 90 , 64*out_scale )        

        out_box = self.vqt_box_head(output)
        out_cls = self.vqt_cls_head(output)#.view(batch_size , self.max_predict_count , -1 )

        '''
        '''
        clampped_box = torch.zeros((batch_size , self.max_predict_count , 5)).to(output.device)
        out_cls = F.interpolate(out_cls.view(batch_size,-1).unsqueeze(0) , self.max_predict_count)[0].view(batch_size , self.max_predict_count,1)
        
        clampped_box[:,:,0] =  F.interpolate( out_box[:,:,0].unsqueeze(0), self.max_predict_count)
        clampped_box[:,:,1] =  F.interpolate( out_box[:,:,1].unsqueeze(0), self.max_predict_count)
        clampped_box[:,:,2] =  F.interpolate( out_box[:,:,2].unsqueeze(0), self.max_predict_count)
        clampped_box[:,:,3] =  F.interpolate( out_box[:,:,3].unsqueeze(0), self.max_predict_count)
        clampped_box[:,:,4] =  F.interpolate( out_box[:,:,4].unsqueeze(0), self.max_predict_count)
        
        '''
        plt.imshow(out_cls[0].repeat(10,1).detach().cpu().numpy())
        plt.title("encoder cls output")
        plt.show()
        top_k = torch.topk(out_cls ,self.top_k_num , dim= 1 )
        top_k_idx = top_k[1].view(batch_size  , self.top_k_num)  # [b , top_k ]        
        #print("top_k_idx" , top_k_idx.shape)

        topk_box = out_box.gather(1 , top_k_idx.unsqueeze(-1).repeat(1,1,5))
        topk_cls = out_cls.gather(1 , top_k_idx.unsqueeze(-1).repeat(1,1,1))
        
        return topk_box , topk_cls , top_k_idx
        '''
        return clampped_box , out_cls
    
    #@torch.no_grad()
    def find_match(self, gt , pred):
        #print("gt"  , gt)
        gt_vec = torch.stack(gt).permute(1,0)
        pred_vec = torch.stack(pred).permute(1,0)
        loss_dist = torch.cdist(gt_vec , pred_vec)
        
        each_gt_pred_best_idx = torch.argmax(loss_dist, 0 )      
        #print("each_gt_pred_best_idx" , each_gt_pred_best_idx)          

        return gt_vec[each_gt_pred_best_idx] , pred_vec
        
    def pack_visualize(self, gt_u_b , gt_vtop_b , gt_vbtm_b , gt_du_b , gt_dvtop_b , dv_btm_b ):
        
        sizes = [len(t) for t in gt_u_b]               
        if isinstance(gt_u_b, torch.Tensor):
            us = gt_u_b.flatten().unsqueeze(0).repeat(2, 1).permute(1,0).reshape(-1)
            us[1::2]+=gt_du_b.flatten()
            us = torch.split(us.view(-1,2) , sizes)

            tops = gt_vtop_b.flatten().unsqueeze(0).repeat(2, 1).permute(1,0).reshape(-1)
            tops[1::2]=gt_dvtop_b.flatten()
            tops = torch.split(tops.view(-1,2) , sizes)

            btms = gt_vbtm_b.flatten().unsqueeze(0).repeat(2, 1).permute(1,0).reshape(-1)
            btms[1::2]=dv_btm_b.flatten()
            btms = torch.split(btms.view(-1,2) , sizes)

        elif isinstance(gt_u_b, tuple) and all(isinstance(t, torch.Tensor) for t in gt_u_b):        
            us = torch.cat(gt_u_b).view(-1).unsqueeze(0).repeat(2, 1).permute(1,0).reshape(-1)
            us[1::2]+=torch.cat(gt_du_b).view(-1)
            us = torch.split(us.view(-1,2) , sizes)

            tops = torch.cat(gt_vtop_b).view(-1).unsqueeze(0).repeat(2, 1).permute(1,0).reshape(-1)
            tops[1::2]=torch.cat(gt_dvtop_b).view(-1)
            tops = torch.split(tops.view(-1,2) , sizes)

            btms = torch.cat(gt_vbtm_b).view(-1).unsqueeze(0).repeat(2, 1).permute(1,0).reshape(-1)
            btms[1::2]=torch.cat(dv_btm_b).view(-1)
            btms = torch.split(btms.view(-1,2) , sizes)
        else:
            assert("Wrong Type.")
        
        return us , tops ,btms
        
        pass

    def training_step(self , input_b ,batch_idx , optimizer_idx):
        
        img = input_b['image']
        #h,w = img.shape[1:3]
        out_box , out_cls  = self.forward(img)  # [ batch , top_k , 5]   , [ batch , top_k , 1] 
        
        batch_size = out_box.shape[0]
               
        gt_u_b = unpad_data( input_b['u'])          
        gt_vtop_b =unpad_data(input_b['v_top'])
        gt_vbtm_b = unpad_data (input_b['v_btm'])
        gt_du_b = unpad_data(input_b['du'])
        gt_dvtop_b = unpad_data(input_b['dv_top'])
        gt_dv_btm_b = unpad_data(input_b['dv_btm'])

        #selected_gt_u_grad =  input_b['u_grad'].view(batch_size , 1024 , 1).gather(1 , top_k_idx.unsqueeze(-1).repeat(1,1,1)).view(batch_size , self.top_k_num)        

        total_loss = 0
        b_cnt = 0
        for u,vtop,vbtm,du,dvtop, dvbtm , pred ,cls_b,gt_cls in zip(gt_u_b , gt_vtop_b , gt_vbtm_b , gt_du_b , gt_dvtop_b , gt_dv_btm_b , out_box , out_cls , input_b['u_grad']):
        #for u,vtop,vbtm,du,dvtop, dvbtm , pred ,cls_b,gt_cls , k_idx in zip(gt_u_b , gt_vtop_b , gt_vbtm_b , gt_du_b , gt_dvtop_b , gt_dv_btm_b , out_box , out_cls ,selected_gt_u_grad , top_k_idx):
            
            # match            
            gt_box =  torch.vstack([vtop,vbtm,du ,dvtop , dvbtm]).permute(1,0)
            
            box_loss = torch.cdist( pred , gt_box , p=1)
            cls_loss = - F.softmax( cls_b , -1)

            #cost_matrix = box_loss * self.box_cost + cls_loss * self.cls_cost
            '''
            cost_matrix =  torch.cdist( F.sigmoid(cls_b).view(-1,1) , gt_u_b.view(-1,1))            
            cost_matrix = cost_matrix.detach().cpu().numpy()            
            row_idx  , col_idx = linear_sum_assignment(cost_matrix)                      
            '''
            
            row_idx = torch.round(u* self.max_predict_count ).to(torch.long).detach().cpu().numpy()           
            col_idx = torch.argsort(u).detach().cpu().numpy() #torch.arange(u.shape[0]).detach().cpu().numpy()           

            matched_cls_loss = F.binary_cross_entropy_with_logits(cls_b[row_idx].view(-1) , gt_cls[col_idx])
            
            total_loss += F.l1_loss(pred[row_idx] ,  gt_box[col_idx])*5 + F.binary_cross_entropy_with_logits(cls_b.view(-1), gt_cls  ) + matched_cls_loss
            
            with torch.no_grad():
                #if self.current_epoch % 5 == 0  :                
                if self.current_epoch % 5 == 0 and self.current_epoch > 0 :                
                    save_path =  os.path.join(self.log_folder , f"gt_ep_{self.current_epoch}-{self.global_step}-{b_cnt}" )
                    gt_us , gt_tops , gt_btms = self.pack_visualize(u.view(1 , -1 ) , vtop , vbtm , du , dvtop , dvbtm )
                    #print("gt_us , gt_tops , gt_btms" , gt_us , gt_tops , gt_btms)
                    vis_imgs = visualize_2d_single(gt_us , gt_tops , gt_btms , u_grad =  gt_cls.view(1 , -1 ), imgs= img[b_cnt] , title="GT",save_path=save_path )                
                    
                    save_path =  os.path.join(self.log_folder , f"pred_ep_{self.current_epoch}-{self.global_step}-{b_cnt}" )
                    pred_u = row_idx / self.max_predict_count                    
                    #pred_u = k_idx[row_idx] / 1024
                    #pred_u = pred_u.detach().cpu().numpy()
                    #print("pred_u" , pred_u)
                    pred_u = torch.from_numpy(pred_u.flatten()[np.newaxis,...]).to(u.device)
                    pred_us , pred_tops , pred_btms = self.pack_visualize(pred_u, pred[row_idx,0],pred[row_idx,1],pred[row_idx,2],pred[row_idx,3],pred[row_idx,4] )
                    #print("pred_us , pred_tops , pred_btms" , pred_us , pred_tops , pred_btms)
                    vis_imgs = visualize_2d_single(pred_us , pred_tops , pred_btms , u_grad = F.sigmoid(cls_b).view(1 , -1 ) , imgs=  img[b_cnt] , title="Pred" , save_path= save_path  )
                    
           
            b_cnt+=1
            pass        
        return total_loss
        pass    

    def configure_optimizers(self):
        backbone_opt = optim.Adam(self.backbone.parameters() , lr=0.00035)
        transforms_opt = optim.Adam(self.transformer.parameters() , lr=0.000035)

        return [backbone_opt , transforms_opt] , []


    pass

# Unit testing...

# Test
dm = CustomDataModule ( train_dir= f"../anno/test_visiable_10_no_cross.json" ,
                        test_dir= f"../anno/test_visiable_10_no_cross.json" , padding_count=256 , use_aug=False , c= 0.1
                       )
m = VerticalQueryTransformer(max_predict_count = 256 , hidden_out=256 , load_weight="D:/OneDrive/OneDrive - NTHU/Layout/Horizon/0912_all_bk.pth"  , backbone_trainable=True, top_k=90)
#img = torch.randn((3,3,1024,512))
#o = m(img)

#print(o)
trainer = pl.Trainer(accelerator='gpu' , devices=1 ,min_epochs=1, max_epochs=51 , precision=16 , fast_dev_run=False )
trainer.fit(m , dm)



position torch.Size([1024, 1])
div_term torch.Size([256])
pe torch.Size([1024, 256])


Using 16bit None Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3060') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                 | Type                    | Params
-----------------------------------------------------------------
0 | backbone             | Resnet                  | 23.5 M
1 | fixed_pe             | PositionalEncoding      | 0     
2 | transformer          | TransformerEncoderLayer | 1.3 M 
3 | vqt_box_head         | Linear                  | 1.3 K 
4 | vqt_cls_head         |

Epoch 5:   0%|          | 0/5 [00:00<?, ?it/s, loss=3.55, v_num=84]        

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 5:  20%|██        | 1/5 [00:02<00:08,  2.09s/it, loss=3.41, v_num=84]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 5:  40%|████      | 2/5 [00:04<00:06,  2.00s/it, loss=3.38, v_num=84]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 5:  60%|██████    | 3/5 [00:05<00:03,  1.98s/it, loss=3.19, v_num=84]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 5:  80%|████████  | 4/5 [00:07<00:01,  1.97s/it, loss=3.15, v_num=84]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 10:   0%|          | 0/5 [00:00<?, ?it/s, loss=2.52, v_num=84]       

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 10:  20%|██        | 1/5 [00:02<00:08,  2.22s/it, loss=2.48, v_num=84]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 10:  40%|████      | 2/5 [00:04<00:06,  2.18s/it, loss=2.46, v_num=84]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 10:  60%|██████    | 3/5 [00:06<00:04,  2.17s/it, loss=2.42, v_num=84]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 10:  80%|████████  | 4/5 [00:08<00:02,  2.14s/it, loss=2.45, v_num=84]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 15:   0%|          | 0/5 [00:00<?, ?it/s, loss=2.13, v_num=84]        

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 15:  20%|██        | 1/5 [00:02<00:08,  2.10s/it, loss=2.19, v_num=84]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 15:  40%|████      | 2/5 [00:04<00:06,  2.06s/it, loss=2.09, v_num=84]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 15:  60%|██████    | 3/5 [00:06<00:04,  2.05s/it, loss=2.09, v_num=84]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 15:  80%|████████  | 4/5 [00:08<00:02,  2.05s/it, loss=2.12, v_num=84]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 16: 100%|██████████| 5/5 [00:02<00:00,  1.89it/s, loss=2.15, v_num=84]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [8]:
import torch
horizon_path =r"D:/OneDrive/OneDrive - NTHU/Layout/Horizon/0912_all_bk.pth"
#models_dict = torch.load_s
checkpoint = torch.load(horizon_path ,  map_location="cpu")
print(checkpoint['state_dict'].keys())
#model.load_state_dict(checkpoint['model_state_dict'])
pretrained_dict = {k: v for k, v in checkpoint['state_dict'].items() if k in m.state_dict()}
m.load_state_dict(pretrained_dict , strict=False)

odict_keys(['v_head.weight', 'v_head.bias', 'du_head.weight', 'du_head.bias', 'cls_head.weight', 'cls_head.bias', 'feature_extractor.encoder.conv1.1.weight', 'feature_extractor.encoder.bn1.weight', 'feature_extractor.encoder.bn1.bias', 'feature_extractor.encoder.bn1.running_mean', 'feature_extractor.encoder.bn1.running_var', 'feature_extractor.encoder.bn1.num_batches_tracked', 'feature_extractor.encoder.layer1.0.conv1.weight', 'feature_extractor.encoder.layer1.0.bn1.weight', 'feature_extractor.encoder.layer1.0.bn1.bias', 'feature_extractor.encoder.layer1.0.bn1.running_mean', 'feature_extractor.encoder.layer1.0.bn1.running_var', 'feature_extractor.encoder.layer1.0.bn1.num_batches_tracked', 'feature_extractor.encoder.layer1.0.conv2.1.weight', 'feature_extractor.encoder.layer1.0.bn2.weight', 'feature_extractor.encoder.layer1.0.bn2.bias', 'feature_extractor.encoder.layer1.0.bn2.running_mean', 'feature_extractor.encoder.layer1.0.bn2.running_var', 'feature_extractor.encoder.layer1.0.bn2.num_

_IncompatibleKeys(missing_keys=['backbone.encoder.conv1.weight', 'backbone.encoder.bn1.weight', 'backbone.encoder.bn1.bias', 'backbone.encoder.bn1.running_mean', 'backbone.encoder.bn1.running_var', 'backbone.encoder.layer1.0.conv1.weight', 'backbone.encoder.layer1.0.bn1.weight', 'backbone.encoder.layer1.0.bn1.bias', 'backbone.encoder.layer1.0.bn1.running_mean', 'backbone.encoder.layer1.0.bn1.running_var', 'backbone.encoder.layer1.0.conv2.weight', 'backbone.encoder.layer1.0.bn2.weight', 'backbone.encoder.layer1.0.bn2.bias', 'backbone.encoder.layer1.0.bn2.running_mean', 'backbone.encoder.layer1.0.bn2.running_var', 'backbone.encoder.layer1.0.conv3.weight', 'backbone.encoder.layer1.0.bn3.weight', 'backbone.encoder.layer1.0.bn3.bias', 'backbone.encoder.layer1.0.bn3.running_mean', 'backbone.encoder.layer1.0.bn3.running_var', 'backbone.encoder.layer1.0.downsample.0.weight', 'backbone.encoder.layer1.0.downsample.1.weight', 'backbone.encoder.layer1.0.downsample.1.bias', 'backbone.encoder.layer1

In [30]:
a = torch.rand(200).view(2,100,1)
#aa = a[:,:,0].unsqueeze(0)
#print(aa.shape)
b = F.interpolate(a.view(2,-1).unsqueeze(0), 10 )[0]
print(b.shape)

torch.Size([2, 10])


In [7]:
a = torch.rand(20).view(-1,1)
b = torch.rand(5).view(-1,1)
c = torch.cdist(a,b)
print(c)

tensor([[0.1426, 0.5408, 0.1219, 0.1724, 0.7682],
        [0.2389, 0.1592, 0.2596, 0.5540, 0.3866],
        [0.5219, 0.1237, 0.5426, 0.8369, 0.1037],
        [0.3582, 0.0399, 0.3789, 0.6733, 0.2673],
        [0.6216, 0.2235, 0.6423, 0.9367, 0.0039],
        [0.3501, 0.7483, 0.3295, 0.0351, 0.9757],
        [0.2889, 0.6871, 0.2683, 0.0261, 0.9145],
        [0.3629, 0.0352, 0.3836, 0.6780, 0.2626],
        [0.4082, 0.0101, 0.4289, 0.7233, 0.2173],
        [0.4492, 0.0511, 0.4699, 0.7643, 0.1763],
        [0.5382, 0.1401, 0.5589, 0.8533, 0.0873],
        [0.0368, 0.3614, 0.0575, 0.3518, 0.5888],
        [0.3789, 0.0193, 0.3995, 0.6939, 0.2467],
        [0.2068, 0.1913, 0.2275, 0.5219, 0.4187],
        [0.2395, 0.1587, 0.2602, 0.5545, 0.3861],
        [0.1033, 0.5014, 0.0826, 0.2118, 0.7288],
        [0.3085, 0.0896, 0.3292, 0.6236, 0.3170],
        [0.1729, 0.2252, 0.1936, 0.4880, 0.4526],
        [0.5994, 0.2012, 0.6201, 0.9144, 0.0262],
        [0.2217, 0.6199, 0.2010, 0.0933, 0.8473]])

In [29]:
#a = torch.rand(2,100,1)
a = torch.arange(400).view(2,100,2)
b = torch.arange(20).view(2,10)

print(a)
print(b)
b= b.unsqueeze(-1).repeat(1,1,2)
print(b)
#print("b unsqueeze",b.unsqueeze(-1))

c = a.gather(1, b)
print(c.shape)
print(c)


tensor([[[  0,   1],
         [  2,   3],
         [  4,   5],
         [  6,   7],
         [  8,   9],
         [ 10,  11],
         [ 12,  13],
         [ 14,  15],
         [ 16,  17],
         [ 18,  19],
         [ 20,  21],
         [ 22,  23],
         [ 24,  25],
         [ 26,  27],
         [ 28,  29],
         [ 30,  31],
         [ 32,  33],
         [ 34,  35],
         [ 36,  37],
         [ 38,  39],
         [ 40,  41],
         [ 42,  43],
         [ 44,  45],
         [ 46,  47],
         [ 48,  49],
         [ 50,  51],
         [ 52,  53],
         [ 54,  55],
         [ 56,  57],
         [ 58,  59],
         [ 60,  61],
         [ 62,  63],
         [ 64,  65],
         [ 66,  67],
         [ 68,  69],
         [ 70,  71],
         [ 72,  73],
         [ 74,  75],
         [ 76,  77],
         [ 78,  79],
         [ 80,  81],
         [ 82,  83],
         [ 84,  85],
         [ 86,  87],
         [ 88,  89],
         [ 90,  91],
         [ 92,  93],
         [ 94

In [None]:
a=torch.arange(5)
b=torch.arange(5)
c=torch.arange(5)

d = torch.vstack([a,b,c]).permute(1,0)
print(d)


In [None]:
from scipy.optimize import linear_sum_assignment
a = torch.tensor([ [0,1,2]  ,  [0,3,5] , [1,0,5] ]).to(torch.float32)
b = torch.tensor([ [0,1,2] , [1,0,5] ]).to(torch.float32)

cost = torch.cdist(b,a)
print(cost)
row , col = linear_sum_assignment(cost,)
print(row)
print(col)

In [None]:
x = torch.tensor([[0.7605, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7730, 0.5752, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7057, 0.5861, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.8386, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.8304, 0.7823, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7034, 0.5994, 0.5691, 0.5652, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6996, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.8305, 0.7819, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.8238, 0.7839, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
# Get the indices of non-zero elements
non_zero_indices = torch.nonzero(x)
print(non_zero_indices)
# Get the non-zero values
non_zero_values = x[non_zero_indices[:,0], non_zero_indices[:,1]]

unique = torch.unique(non_zero_indices[:,0] ,return_counts=True)
print("unique" , unique)
# Print the result
print(non_zero_values)
non_zero_values = torch.split(non_zero_values , tuple(unique[1]))
print("split non_zero_values" , non_zero_values)

def unpad_data( x :[Tensor] ) :
	non_zero_indices = torch.nonzero(x)
	print(non_zero_indices)
	# Get the non-zero values
	non_zero_values = x[non_zero_indices[:,0], non_zero_indices[:,1]]

	unique = torch.unique(non_zero_indices[:,0] ,return_counts=True)
	print("unique" , unique)
	# Print the result
	print(non_zero_values)
	non_zero_values = torch.split(non_zero_values , tuple(unique[1]))
	print("split non_zero_values" , non_zero_values)
	return non_zero_values

In [None]:
a = torch.tensor([[0.58 , 0.6] , [0.4] ] , )
b = torch.tensor([0.1 , 0.2] , )

c = a.repeat(2)
print(a.repeat(2))

In [None]:
transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
src = torch.rand((10, 32, 512))
tgt = torch.rand((20, 32, 512))
out = transformer_model(src, tgt)

print(out.shape)

In [None]:
import torch

print(torch.rand(1)[0])