In [1]:
# Data Loader
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
import pytorch_lightning as pl
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.data import random_split
import json
from CustomDataset import * 
import sys
sys.path.append('../')
from config import *
from file_helper import *

def collate_fn(batch):
    return tuple(zip(*batch))
#=================================
#             Augmentation
#=================================

def gauss_noise_tensor(img):
    rand = torch.rand(1)[0]
    if rand < 0.5 and Horizon_AUG:
        sigma = rand *0.125
        out = img + sigma * torch.randn_like(img)
        return out
    return img

def blank(img):    
    return img

class CustomDataModule(pl.LightningDataModule):
    def __init__(self , train_dir , test_dir , batch_size = 2, num_workers = 0 , img_size=[IMG_WIDTH, IMG_HEIGHT] , use_aug = True ,padding_count = 24 ,c =0.1 ):
        super().__init__()
        self.train_dir = train_dir
        self.test_dir = test_dir
        
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.img_size = img_size      
        self.use_aug = use_aug
        self.padding_count  = padding_count
        self.c = c
        
        pass

    def prepare_data(self) -> None:
        # Download dataset
        pass

    def setup(self, stage):
        # Create dataset...          
                
        self.entire_dataset = CustomDataset(self.train_dir  , use_aug= self.use_aug , padding_count= self.padding_count , c=self.c)
        self.train_ds , self.val_ds = random_split(self.entire_dataset , [0.9, 0.1])        
        self.test_ds = CustomDataset(self.test_dir  , use_aug= False)
        
        pass

    # ToDo: Reture Dataloader...
    def train_dataloader(self) -> TRAIN_DATALOADERS:
        return DataLoader(self.train_ds , batch_size= self.batch_size , num_workers= self.num_workers , shuffle=True)
    
    def val_dataloader(self) -> EVAL_DATALOADERS:
        return DataLoader(self.val_ds , batch_size= self.batch_size , num_workers= self.num_workers , shuffle=False)
    
    def test_dataloader(self) -> EVAL_DATALOADERS:
        return DataLoader(self.test_ds , batch_size= self.batch_size , num_workers= self.num_workers , shuffle=False)

    pass


# Test
dm = CustomDataModule ( train_dir= f"../anno/test_visiable_10_no_cross.json" ,
                       test_dir= f"../anno/test_visiable_10_no_cross.json" , padding_count=256
                       )

  from .autonotebook import tqdm as notebook_tqdm


d:\Projects\Layout\NTHU_CGV_Layout_exp\ours
d:\Projects\Layout\NTHU_CGV_Layout_exp\Horizon_and_SAM\Horizon


In [2]:
from torch import Tensor
def unpad_data( x :[Tensor] ) :
	non_zero_indices = torch.nonzero(x)
	#print(non_zero_indices)
	# Get the non-zero values
	non_zero_values = x[non_zero_indices[:,0], non_zero_indices[:,1]]

	unique = torch.unique(non_zero_indices[:,0] ,return_counts=True)
	#print("unique" , unique)
	# Print the result
	#print(non_zero_values)
	non_zero_values = torch.split(non_zero_values , tuple(unique[1]))
	#print("split non_zero_values" , non_zero_values)
	return non_zero_values

In [12]:
import torch
from torch import nn
from torch.nn import functional as F
from typing import Any
import pytorch_lightning as pl
from config import *
import torchvision.models as models
from torchvision.ops import MLP
import math
from torch import Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import matplotlib.pyplot as plt
from scipy.optimize import linear_sum_assignment

# Model
class ConvCompressH(nn.Module):
    ''' Reduce feature height by factor of two '''
    def __init__(self, in_c, out_c, ks=3):
        super(ConvCompressH, self).__init__()
        assert ks % 2 == 1
        self.layers = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=ks, stride=(2, 1), padding=ks//2),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.layers(x)

class GlobalHeightConv(nn.Module):
    def __init__(self, in_c, out_c):
        super(GlobalHeightConv, self).__init__()
        self.layer = nn.Sequential(
            ConvCompressH(in_c, in_c//2),
            ConvCompressH(in_c//2, in_c//2),
            ConvCompressH(in_c//2, in_c//4),
            ConvCompressH(in_c//4, out_c),
        )

    def forward(self, x, out_w):
        x = self.layer(x)
        assert out_w % x.shape[3] == 0
        factor = out_w // x.shape[3]
        x = torch.cat([x[..., -1:], x, x[..., :1]], 3)
        x = F.interpolate(x, size=(x.shape[2], out_w + 2 * factor), mode='bilinear', align_corners=False)
        x = x[..., factor:-factor]        
        return x
class Resnet(nn.Module):
    def __init__(self, backbone='resnet50', pretrained=True):
        super(Resnet, self).__init__()
        #assert backbone in ENCODER_RESNET
        self.encoder = getattr(models, backbone)(pretrained=pretrained)
        del self.encoder.fc, self.encoder.avgpool

    def forward(self, x):
        features = []
        x = self.encoder.conv1(x)
        x = self.encoder.bn1(x)
        x = self.encoder.relu(x)
        x = self.encoder.maxpool(x)

        x = self.encoder.layer1(x);  features.append(x)  # 1/4
        x = self.encoder.layer2(x);  features.append(x)  # 1/8
        x = self.encoder.layer3(x);  features.append(x)  # 1/16
        x = self.encoder.layer4(x);  features.append(x)  # 1/32
        return features

    def list_blocks(self):
        lst = [m for m in self.encoder.children()]
        block0 = lst[:4]
        block1 = lst[4:5]
        block2 = lst[5:6]
        block3 = lst[6:7]
        block4 = lst[7:8]
        return block0, block1, block2, block3, block4
    
class GlobalHeightStage(nn.Module):
    def __init__(self, c1, c2, c3, c4, out_scale=8 , pretrain_weight= ""):
        ''' Process 4 blocks from encoder to single multiscale features '''
        super(GlobalHeightStage, self).__init__()
        self.cs = c1, c2, c3, c4
        self.out_scale = out_scale
        self.ghc_lst = nn.ModuleList([
            GlobalHeightConv(c1, c1//out_scale),
            GlobalHeightConv(c2, c2//out_scale),
            GlobalHeightConv(c3, c3//out_scale),
            GlobalHeightConv(c4, c4//out_scale),
        ])

        if(pretrain_weight!=""):

            pass

    def forward(self, conv_list, out_w):
        assert len(conv_list) == 4
        bs = conv_list[0].shape[0]
        feature = torch.cat([
            f(x, out_w).reshape(bs, -1, out_w)
            for f, x, out_c in zip(self.ghc_lst, conv_list, self.cs)
        ], dim=1)
        return feature

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
    
class TransformerModel(nn.Module):
    def __init__(self, ntoken: int, d_model: int, nhead: int , d_hid: int, nlayers: int, dropout: float = 0.1):
        super().__init__()
        self.ntoken = ntoken
        self.model_type = 'Transformer'
        self.d_model = d_model
        
        self.enc_embedding = nn.Embedding(num_embeddings=ntoken ,embedding_dim= d_model)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model , nhead , d_hid , dropout) 
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=6)

        self.dec_embedding = nn.Embedding(num_embeddings=ntoken ,embedding_dim= d_model)
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead , dropout=dropout)
        self.decoder = nn.TransformerDecoder( self.decoder_layer  , num_layers=6)

        '''
        #self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.d_model = d_model
        self.linear = nn.Linear(d_model, ntoken)
        self.init_weights()
        '''

    def init_weights(self) -> None:
        initrange = 0.1
        #self.embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
        #src = self.embedding(src) * math.sqrt(self.d_model)
        #print("trans forward " , src.shape)
        #src = self.pos_encoder(src)
        #output = self.transformer_encoder(src, src_mask)
        #output = self.linear(output)
        #return output
        #print("self.ntoken" , self.ntoken)
        #print("self.src" , src.shape)
        pos_idx = torch.arange(self.ntoken , device=src.device)

        pos_src = self.enc_embedding(pos_idx)* math.sqrt(self.d_model) + src
        pos_enc_src = self.encoder(pos_src )

        pos_dec_src = self.dec_embedding(pos_idx)* math.sqrt(self.d_model) + pos_enc_src
        #pos_dec_src = self.dec_embedding( torch.tensor([0,1],device=src.device) )* math.sqrt(self.d_model) + pos_enc_src
        dec_src = self.decoder( pos_enc_src  ,pos_dec_src)

        #print("dec_src" , dec_src.shape)    # [batch , token , hidden]
        return dec_src


    

class VerticalQueryTransformer(pl.LightningModule):    
    def __init__(self  ,  max_predict_count = 24 , hidden_out = 128 , class_num = 1 , log_folder = "__test" , num_classes = 1 , load_weight =""):
        #print(" input_size" ,  input_size)
        super().__init__()
        self.backbone = Resnet()
        self.out_scale = 8
        self.step_cols = 4        
        self.hidden_size = hidden_out
        self.max_predict_count = max_predict_count
        self.num_classes  = num_classes 

        self.transformer = TransformerModel( ntoken= max_predict_count , d_model=hidden_out , nhead=8 , d_hid= 2048,nlayers=6 )

        #self.box_head= nn.Linear( hidden_out , 6 )        
        self.vqt_box_head= nn.Linear( hidden_out , 5 )        
        self.vqt_cls_head= nn.Linear( hidden_out , class_num )        
        self.confidence_threshold = 0.5

        # loss
        self.box_cost = 1
        self.cls_cost = 20

        self.log_folder = create_folder(os.path.join(os.getcwd() , "output" , log_folder))
        #self.box_head.bias.data = torch.nn.Parameter(torch.tensor([0.3,0.2,0.2,0.3]))
        #self.box_head.weight.data.fill_(0)
        
        # Inference channels number from each block of the encoder
        with torch.no_grad():
            dummy = torch.zeros(1, 3, 320, 190)
            c1, c2, c3, c4 = [b.shape[1] for b in self.backbone(dummy)] # resnet feature channel數
            #print("c1, c2, c3, c4" , c1, c2, c3, c4)
            c_last = (c1*8 + c2*4 + c3*2 + c4*1) // self.out_scale            
        self.reduce_height_module = GlobalHeightStage(c1, c2, c3, c4 , out_scale=self.out_scale , pretrain_weight= "")
        self.v_reproj = nn.Conv2d(1024 , self.max_predict_count,kernel_size=1)
        
        if(load_weight != ""):
            checkpoint = torch.load(load_weight ,  map_location="cpu")
            pretrained_dict = {k: v for k, v in checkpoint['state_dict'].items() if k in self.state_dict()}
            self.load_state_dict(pretrained_dict , strict=False)
            '''
            for i, (name, param) in enumerate(self.named_parameters()):
                if "reduce_height_module" in name:
                    param.requires_grad = False
            '''
            pass


    def forward(self ,x ):
        features = self.backbone(x) # [4 , c , h, w]
        
        reduced_feats  = self.reduce_height_module(features , x.shape[3]//self.step_cols ) # [b , 1024 ,  256] width = 1024 , 256d latent code each.
        
        #print("reduced_feats out" , reduced_feats.shape) # [b , 1024 , 256]        
        reduced_feats = self.v_reproj(reduced_feats.view(reduced_feats.shape[0] , 1024 , 256,-1)).view(reduced_feats.shape[0]  , 256 , 256)
        #print("reduced_feats" , reduced_feats.shape)
        output = self.transformer(reduced_feats)  # (b , max_pred , 90 )

        batch_size = output.shape[0]
        output = output.view( batch_size , self.max_predict_count , -1 ) # (b , 90 , 64*out_scale )        

        out_box = self.vqt_box_head(output)
        out_cls = self.vqt_cls_head(output).view(batch_size , self.max_predict_count , -1 )
        
        #print("out_box" , out_box.shape) 
        #print("out_cls" , out_cls.shape) 
        return out_box , out_cls
    
    #@torch.no_grad()
    def find_match(self, gt , pred):
        #print("gt"  , gt)
        gt_vec = torch.stack(gt).permute(1,0)
        pred_vec = torch.stack(pred).permute(1,0)
        loss_dist = torch.cdist(gt_vec , pred_vec)
        
        each_gt_pred_best_idx = torch.argmax(loss_dist, 0 )      
        #print("each_gt_pred_best_idx" , each_gt_pred_best_idx)          

        return gt_vec[each_gt_pred_best_idx] , pred_vec
        
    def pack_visualize(self, gt_u_b , gt_vtop_b , gt_vbtm_b , gt_du_b , gt_dvtop_b , dv_btm_b ):
        
        sizes = [len(t) for t in gt_u_b]               
        if isinstance(gt_u_b, torch.Tensor):
            us = gt_u_b.flatten().unsqueeze(0).repeat(2, 1).permute(1,0).reshape(-1)
            us[1::2]+=gt_du_b.flatten()
            us = torch.split(us.view(-1,2) , sizes)

            tops = gt_vtop_b.flatten().unsqueeze(0).repeat(2, 1).permute(1,0).reshape(-1)
            tops[1::2]=gt_dvtop_b.flatten()
            tops = torch.split(tops.view(-1,2) , sizes)

            btms = gt_vbtm_b.flatten().unsqueeze(0).repeat(2, 1).permute(1,0).reshape(-1)
            btms[1::2]=dv_btm_b.flatten()
            btms = torch.split(btms.view(-1,2) , sizes)

        elif isinstance(gt_u_b, tuple) and all(isinstance(t, torch.Tensor) for t in gt_u_b):        
            us = torch.cat(gt_u_b).view(-1).unsqueeze(0).repeat(2, 1).permute(1,0).reshape(-1)
            us[1::2]+=torch.cat(gt_du_b).view(-1)
            us = torch.split(us.view(-1,2) , sizes)

            tops = torch.cat(gt_vtop_b).view(-1).unsqueeze(0).repeat(2, 1).permute(1,0).reshape(-1)
            tops[1::2]=torch.cat(gt_dvtop_b).view(-1)
            tops = torch.split(tops.view(-1,2) , sizes)

            btms = torch.cat(gt_vbtm_b).view(-1).unsqueeze(0).repeat(2, 1).permute(1,0).reshape(-1)
            btms[1::2]=torch.cat(dv_btm_b).view(-1)
            btms = torch.split(btms.view(-1,2) , sizes)
        else:
            assert("Wrong Type.")
        
        return us , tops ,btms
        
        pass

    def training_step(self , input_b ,batch_idx , optimizer_idx):
        
        img = input_b['image']        

        img = input_b['image']
        #h,w = img.shape[1:3]
        out_box , out_cls = self.forward(img)  # [ batch , n , 5]
        #out_cls =  out_cls.view(batch_size * self.max_predict_count , self.num_classes) # [batch size * max count , num_classes]
        #print("out_box" , out_box.shape)
        #print("out_cls" , out_cls.shape)
        batch_size = out_box.shape[0]
               
        gt_u_b = unpad_data( input_b['u'])          
        gt_vtop_b =unpad_data(input_b['v_top'])
        gt_vbtm_b = unpad_data (input_b['v_btm'])
        gt_du_b = unpad_data(input_b['du'])
        gt_dvtop_b = unpad_data(input_b['dv_top'])
        gt_dv_btm_b = unpad_data(input_b['dv_btm'])
        

        total_loss = 0
        b_cnt = 0
        for u,vtop,vbtm,du,dvtop, dvbtm , pred ,cls_b,gt_cls in zip(gt_u_b , gt_vtop_b , gt_vbtm_b , gt_du_b , gt_dvtop_b , gt_dv_btm_b , out_box , out_cls , input_b['u_grad']):
            # match            
            gt_box =  torch.vstack([vtop,vbtm,du ,dvtop , dvbtm]).permute(1,0)
            
            box_loss = torch.cdist( pred , gt_box , p=1)
            cls_loss = - F.softmax( cls_b , -1)
            cost_matrix = box_loss*self.box_cost + cls_loss * self.cls_cost
            cost_matrix = cost_matrix.detach().cpu().numpy()
            
            row_idx  , col_idx = linear_sum_assignment(cost_matrix)
            #print("idx" , row_idx , col_idx)
            #print("matched out" , pred[row_idx])
            #print("matched gt" , gt_box[col_idx])
            
            matched_cls_loss = F.binary_cross_entropy_with_logits(cls_b[row_idx].view(-1) , gt_cls[col_idx])
            total_loss += F.l1_loss(pred[row_idx] ,  gt_box[col_idx]) + F.binary_cross_entropy_with_logits(cls_b.view(-1), gt_cls  ) + matched_cls_loss
            #total_loss += F.l1_loss(pred[row_idx] ,  gt_box[col_idx]) + matched_cls_loss

            #cls_loss = F.binary_cross_entropy_with_logits( out_cls.view(-1 , self.max_predict_count) ,)
            #matched_gt , matched_pred =  self.find_match( (u,vtop,vbtm,du,dvtop,btm , input_b['u_grad']),(pred[:,0],pred[:,1],pred[:,2],pred[:,3],pred[:,4],pred[:,5]))
            with torch.no_grad():
                #if self.current_epoch % 5 == 0  :                
                if self.current_epoch % 5 == 0 and self.current_epoch > 0 :                
                    save_path = create_folder( os.path.join(self.log_folder , f"gt_ep_{self.current_epoch}-{b_cnt}" ))
                    gt_us , gt_tops , gt_btms = self.pack_visualize(u.view(1 , -1 ) , vtop , vbtm , du , dvtop , dvbtm )
                    #print("gt_us , gt_tops , gt_btms" , gt_us , gt_tops , gt_btms)
                    vis_imgs = visualize_2d_single(gt_us , gt_tops , gt_btms , u_grad = gt_cls.view(1 , -1 ), imgs= img[b_cnt] , title="GT",save_path=save_path )                
                    
                    save_path = create_folder( os.path.join(self.log_folder , f"pred_ep_{self.current_epoch}-{b_cnt}" ))
                    pred_u = row_idx / self.max_predict_count
                    pred_u = torch.from_numpy(pred_u.flatten()[np.newaxis,...]).to(u.device)
                    #print("pred_u" , pred_u)
                    pred_us , pred_tops , pred_btms = self.pack_visualize(pred_u, pred[row_idx,0],pred[row_idx,1],pred[row_idx,2],pred[row_idx,3],pred[row_idx,4] )
                    #print("pred_us , pred_tops , pred_btms" , pred_us , pred_tops , pred_btms)
                    vis_imgs = visualize_2d_single(pred_us , pred_tops , pred_btms , u_grad = cls_b.view(1 , -1 ) , imgs=  img[b_cnt] , title="Pred" , save_path= save_path  )
                    
            #print("matched_gt" , matched_gt)
            #print("matched_pred" , matched_pred)
            #print("===================")
            #l1_loss += F.l1_loss(matched_gt , matched_pred)
            b_cnt+=1
            pass
        '''
        gt_vtop_b =torch.cat(unpad_data(input_b['v_top'])).view( -1)
        gt_vbtm_b = torch.cat(unpad_data (input_b['v_btm'])).view( -1)
        gt_du_b = torch.cat(unpad_data(input_b['du'])).view( -1)
        gt_dvtop_b = torch.cat(unpad_data(input_b['dv_top'])).view( -1)
        gt_dv_btm_b = torch.cat(unpad_data(input_b['dv_btm'])).view( -1)
        gt_box =  torch.vstack(
            [gt_vtop_b,
             gt_vbtm_b,
             gt_du_b ,
             gt_dvtop_b , 
             gt_dv_btm_b
             ]).permute(1,0)
        print("gt_box" , gt_box)

        # cost matrix        
        out_cls =  out_cls.view(batch_size * self.max_predict_count , self.num_classes) # [batch size * max count , num_classes]
        out_box = out_box.view(-1 , 5 ) # [batch size * max count , 5]

        gt_cls = torch.cat([ v for v  in input_b['u_grad']])
        #gt_box = gt_box
        print("input_b['u_grad']" , input_b['u_grad'].shape)

        box_loss = torch.cdist(out_box ,gt_box  , p =1)
        print("box_loss" , box_loss.shape)
        cls_loss = -F.softmax(out_cls , -1)

        cost_matrix = box_loss * self.box_cost + cls_loss*self.cls_cost 
        print("cost_matrix" , cost_matrix.shape)
        cost_matrix = cost_matrix.view(batch_size, self.max_predict_count, -1).detach().cpu().numpy()
        sizes = [self.max_predict_count for v in input_b['u_grad']]       
        print("cost_matrix" , cost_matrix.shape)

        #indices =[]
        matched_gt_idx=[]
        matched_out_idx=[]
        for i , cost_b in enumerate(cost_matrix):
            c = linear_sum_assignment(cost_b)
            print("c" , c)
            #indices.append(c[0] )
            matched_out_idx.append(c[0])
            matched_gt_idx.append(c[1])

            #_idx = torch.tensor(c[0])
            #print("_idx",_idx)
            #print("gt box" , gt_box[ _idx])
            #print("matched box" , out_box[ _idx] )
        #indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
        #print("indices" , indices)
        #indices = [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
        #print("indices" , indices)

        #matched_gt = 
        box_loss = 0
        cls_loss = 0
        i=0
        for gt_idx , out_idx in zip( matched_gt_idx , matched_out_idx):
            box_loss+= F.l1_loss( gt_box[gt_idx+ self.max_predict_count*i ] , out_box[out_idx + self.max_predict_count*i ] )
            i+=1
            pass
        '''


        # match cls
        '''
        out_idx = torch.argwhere(out_cls.view(-1) > self.confidence_threshold ).view(-1)
        batch_idx = torch.floor(out_idx / self.max_predict_count).to(torch.int64)
        batch_split_cnt = torch.unique(batch_idx , return_counts=True)[1]
        print("batch_split_cnt" , batch_split_cnt)

        out_idx = torch.argwhere(out_cls > self.confidence_threshold )
        #print("out_idx" ,  out_idx.shape , out_idx)
        
        out_box_over_threshold = out_box[out_idx[:, 0],out_idx[:, 0],:]
        #print("out_box_over_threshold" ,  out_box_over_threshold.shape)
        pred_box = torch.zeros( out_box_over_threshold.shape[0] , 6,device=img.device)        
        pred_box[:,1:] = out_box_over_threshold
        pred_box [:, 0 ]  = (out_idx[:,1]+0.5) / self.max_predict_count
        #print("pred_box" ,  pred_box.shape)

        batch_split_cnt = torch.unique(out_idx[:,0] , return_counts=True)[1]
        #print("batch_split_cnt" , batch_split_cnt)
        pred_box_b = []        
        prev = 0
        for idx_cnt_b in batch_split_cnt:
            pred_box_b.append(pred_box[prev : idx_cnt_b])
            prev +=idx_cnt_b            
            pass
        '''
        '''
        #pred_box = out_box.view(-1,6)[out_idx]
        out_box_over_threshold = out_box.view(-1,5)[out_idx]
        pred_box = torch.zeros( out_box_over_threshold.shape[0] , 6)        
        pred_box [:,1:]   = out_box_over_threshold
        pred_box [:, 0 ]  = (out_idx+0.5) / w 
        #pred_box = out_box.view(-1,5)[out_idx]  # no u
        pred_box = torch.split(pred_box , tuple(batch_split_cnt) )
        print("pred_box" , len(pred_box) , pred_box[0].shape)
        '''

        '''
        
        # Log every 5 epochs
        with torch.no_grad():
            #if self.current_epoch % 5 == 0 and self.current_epoch > 0 :                
            if self.current_epoch % 5 == 0  :                
                save_path = create_folder( os.path.join(self.log_folder , f"ep_{self.current_epoch}" ))
                gt_us , gt_tops , gt_btms = self.pack_visualize(gt_u_b , gt_vtop_b , gt_vbtm_b , gt_du_b , gt_dvtop_b , gt_dv_btm_b )
                vis_imgs = visualize_2d(gt_us , gt_tops , gt_btms , u_grad = input_b['u_grad'] , imgs= img , title="GT",save_path=save_path )                

                pred_us , pred_tops , pred_btms = self.pack_visualize(out_box[:,0],out_box[:,1],out_box[:,2],out_box[:,3],out_box[:,4],out_box[:,5] )
                vis_imgs = visualize_2d(pred_us , pred_tops , pred_btms , u_grad = out_cls.view(batch_size , -1) , imgs= img , title="Pred" , save_path= save_path )
                
        print("gt_u" , gt_u_b)
        print("gt_vtop" , gt_vtop_b)
        print("gt_vbtm" , gt_vbtm_b)
        print("gt_du" , gt_du_b)
        print("gt_dvtop" , gt_dvtop_b)
        print("dv_btm" , dv_btm_b)

        l1_loss = 0
        for u,vtop,vbtm,du,dvtop,btm , pred in zip(gt_u_b , gt_vtop_b , gt_vbtm_b , gt_du_b , gt_dvtop_b , gt_dv_btm_b , pred_box_b):
            # match
            #print("pred" , pred.shape)
            #print("gt u" , u.shape)
            cls_loss = F.binary_cross_entropy_with_logits( out_cls.view(-1 , self.max_predict_count) ,)
            matched_gt , matched_pred =  self.find_match( (u,vtop,vbtm,du,dvtop,btm , input_b['u_grad']),(pred[:,0],pred[:,1],pred[:,2],pred[:,3],pred[:,4],pred[:,5]))
           
            print("matched_gt" , matched_gt)
            print("matched_pred" , matched_pred)
            print("===================")
            l1_loss += F.l1_loss(matched_gt , matched_pred)
            pass

        # loss:
        #box_loss = 

        #cls_loss = F.binary_cross_entropy_with_logits( out_cls.view(-1 , self.max_predict_count) , input_b['u_grad'])
        total_loss = l1_loss + cls_loss 
        '''
        return total_loss
        pass    

    def configure_optimizers(self):
        backbone_opt = optim.Adam(self.backbone.parameters() , lr=0.00035)
        transforms_opt = optim.Adam(self.transformer.parameters() , lr=0.000035)

        return [backbone_opt , transforms_opt] , []


    pass

# Unit testing...

# Test
dm = CustomDataModule ( train_dir= f"../anno/test_visiable_10_no_cross.json" ,
                        test_dir= f"../anno/test_visiable_10_no_cross.json" , padding_count=256 , use_aug=False , c= 0.95
                       )
m = VerticalQueryTransformer(max_predict_count = 256 , hidden_out=256 , load_weight="D:/OneDrive/OneDrive - NTHU/Layout/Horizon/0912_all_bk.pth")
#img = torch.randn((3,3,1024,512))
#o = m(img)

#print(o)
trainer = pl.Trainer(accelerator='gpu' , devices=1 ,min_epochs=1, max_epochs=51 , precision=16 , fast_dev_run=False )
trainer.fit(m , dm)



Using 16bit None Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3060') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                 | Type              | Params
-----------------------------------------------------------
0 | backbone             | Resnet            | 23.5 M
1 | transformer          | TransformerModel  | 20.4 M
2 | vqt_box_head         | Linear            | 1.3 K 
3 | vqt_cls_head         | Linear            | 257   
4 | reduce_height_module | GlobalHeightStage | 45.5 M
5 | v_re

Epoch 5:   0%|          | 0/5 [00:00<?, ?it/s, loss=1.7, v_num=55]         

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 5:  20%|██        | 1/5 [00:02<00:10,  2.57s/it, loss=1.74, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 5:  40%|████      | 2/5 [00:04<00:07,  2.47s/it, loss=1.68, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 5:  60%|██████    | 3/5 [00:07<00:04,  2.44s/it, loss=1.7, v_num=55] 

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 5:  80%|████████  | 4/5 [00:09<00:02,  2.43s/it, loss=1.68, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 10:   0%|          | 0/5 [00:00<?, ?it/s, loss=1.68, v_num=55]       

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 10:  20%|██        | 1/5 [00:02<00:10,  2.70s/it, loss=1.77, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 10:  40%|████      | 2/5 [00:05<00:07,  2.53s/it, loss=1.74, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 10:  60%|██████    | 3/5 [00:08<00:05,  2.91s/it, loss=1.78, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 10:  80%|████████  | 4/5 [00:11<00:02,  2.80s/it, loss=1.74, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 15:   0%|          | 0/5 [00:00<?, ?it/s, loss=1.72, v_num=55]        

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 15:  20%|██        | 1/5 [00:02<00:10,  2.57s/it, loss=1.74, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 15:  40%|████      | 2/5 [00:05<00:07,  2.52s/it, loss=1.7, v_num=55] 

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 15:  60%|██████    | 3/5 [00:07<00:04,  2.46s/it, loss=1.67, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 15:  80%|████████  | 4/5 [00:09<00:02,  2.42s/it, loss=1.67, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 20:   0%|          | 0/5 [00:00<?, ?it/s, loss=1.66, v_num=55]        

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 20:  20%|██        | 1/5 [00:02<00:10,  2.58s/it, loss=1.67, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 20:  40%|████      | 2/5 [00:05<00:07,  2.52s/it, loss=1.67, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 20:  60%|██████    | 3/5 [00:07<00:04,  2.46s/it, loss=1.65, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 20:  80%|████████  | 4/5 [00:09<00:02,  2.42s/it, loss=1.66, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 25:   0%|          | 0/5 [00:00<?, ?it/s, loss=1.61, v_num=55]        

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 25:  20%|██        | 1/5 [00:02<00:09,  2.48s/it, loss=1.61, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 25:  40%|████      | 2/5 [00:04<00:07,  2.41s/it, loss=1.62, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 25:  60%|██████    | 3/5 [00:07<00:04,  2.44s/it, loss=1.64, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 25:  80%|████████  | 4/5 [00:09<00:02,  2.41s/it, loss=1.65, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 30:   0%|          | 0/5 [00:00<?, ?it/s, loss=1.62, v_num=55]        

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 30:  20%|██        | 1/5 [00:02<00:10,  2.51s/it, loss=1.62, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 30:  40%|████      | 2/5 [00:04<00:07,  2.43s/it, loss=1.68, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 30:  60%|██████    | 3/5 [00:07<00:04,  2.41s/it, loss=1.68, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 30:  80%|████████  | 4/5 [00:09<00:02,  2.40s/it, loss=1.69, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 35:   0%|          | 0/5 [00:00<?, ?it/s, loss=1.56, v_num=55]        

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 35:  20%|██        | 1/5 [00:02<00:10,  2.71s/it, loss=1.51, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 35:  40%|████      | 2/5 [00:05<00:07,  2.51s/it, loss=1.52, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 35:  60%|██████    | 3/5 [00:07<00:04,  2.45s/it, loss=1.58, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 35:  80%|████████  | 4/5 [00:09<00:02,  2.42s/it, loss=1.59, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 40:   0%|          | 0/5 [00:00<?, ?it/s, loss=1.6, v_num=55]         

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 40:  20%|██        | 1/5 [00:02<00:09,  2.36s/it, loss=1.6, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 40:  40%|████      | 2/5 [00:04<00:07,  2.33s/it, loss=1.61, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 40:  60%|██████    | 3/5 [00:06<00:04,  2.33s/it, loss=1.65, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 40:  80%|████████  | 4/5 [00:09<00:02,  2.41s/it, loss=1.61, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 45:   0%|          | 0/5 [00:00<?, ?it/s, loss=1.58, v_num=55]        

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 45:  20%|██        | 1/5 [00:02<00:09,  2.50s/it, loss=1.61, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 45:  40%|████      | 2/5 [00:04<00:07,  2.40s/it, loss=1.64, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 45:  60%|██████    | 3/5 [00:07<00:04,  2.39s/it, loss=1.6, v_num=55] 

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 45:  80%|████████  | 4/5 [00:09<00:02,  2.37s/it, loss=1.58, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 50:   0%|          | 0/5 [00:00<?, ?it/s, loss=1.59, v_num=55]        

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 50:  20%|██        | 1/5 [00:02<00:10,  2.56s/it, loss=1.62, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 50:  40%|████      | 2/5 [00:04<00:07,  2.42s/it, loss=1.58, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 50:  60%|██████    | 3/5 [00:07<00:04,  2.40s/it, loss=1.57, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 50:  80%|████████  | 4/5 [00:09<00:02,  2.39s/it, loss=1.58, v_num=55]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch 50: 100%|██████████| 5/5 [00:11<00:00,  2.27s/it, loss=1.59, v_num=55]

`Trainer.fit` stopped: `max_epochs=51` reached.


Epoch 50: 100%|██████████| 5/5 [00:16<00:00,  3.39s/it, loss=1.59, v_num=55]


In [8]:
import torch
horizon_path =r"D:/OneDrive/OneDrive - NTHU/Layout/Horizon/0912_all_bk.pth"
#models_dict = torch.load_s
checkpoint = torch.load(horizon_path ,  map_location="cpu")
print(checkpoint['state_dict'].keys())
#model.load_state_dict(checkpoint['model_state_dict'])
pretrained_dict = {k: v for k, v in checkpoint['state_dict'].items() if k in m.state_dict()}
m.load_state_dict(pretrained_dict , strict=False)

odict_keys(['v_head.weight', 'v_head.bias', 'du_head.weight', 'du_head.bias', 'cls_head.weight', 'cls_head.bias', 'feature_extractor.encoder.conv1.1.weight', 'feature_extractor.encoder.bn1.weight', 'feature_extractor.encoder.bn1.bias', 'feature_extractor.encoder.bn1.running_mean', 'feature_extractor.encoder.bn1.running_var', 'feature_extractor.encoder.bn1.num_batches_tracked', 'feature_extractor.encoder.layer1.0.conv1.weight', 'feature_extractor.encoder.layer1.0.bn1.weight', 'feature_extractor.encoder.layer1.0.bn1.bias', 'feature_extractor.encoder.layer1.0.bn1.running_mean', 'feature_extractor.encoder.layer1.0.bn1.running_var', 'feature_extractor.encoder.layer1.0.bn1.num_batches_tracked', 'feature_extractor.encoder.layer1.0.conv2.1.weight', 'feature_extractor.encoder.layer1.0.bn2.weight', 'feature_extractor.encoder.layer1.0.bn2.bias', 'feature_extractor.encoder.layer1.0.bn2.running_mean', 'feature_extractor.encoder.layer1.0.bn2.running_var', 'feature_extractor.encoder.layer1.0.bn2.num_

_IncompatibleKeys(missing_keys=['backbone.encoder.conv1.weight', 'backbone.encoder.bn1.weight', 'backbone.encoder.bn1.bias', 'backbone.encoder.bn1.running_mean', 'backbone.encoder.bn1.running_var', 'backbone.encoder.layer1.0.conv1.weight', 'backbone.encoder.layer1.0.bn1.weight', 'backbone.encoder.layer1.0.bn1.bias', 'backbone.encoder.layer1.0.bn1.running_mean', 'backbone.encoder.layer1.0.bn1.running_var', 'backbone.encoder.layer1.0.conv2.weight', 'backbone.encoder.layer1.0.bn2.weight', 'backbone.encoder.layer1.0.bn2.bias', 'backbone.encoder.layer1.0.bn2.running_mean', 'backbone.encoder.layer1.0.bn2.running_var', 'backbone.encoder.layer1.0.conv3.weight', 'backbone.encoder.layer1.0.bn3.weight', 'backbone.encoder.layer1.0.bn3.bias', 'backbone.encoder.layer1.0.bn3.running_mean', 'backbone.encoder.layer1.0.bn3.running_var', 'backbone.encoder.layer1.0.downsample.0.weight', 'backbone.encoder.layer1.0.downsample.1.weight', 'backbone.encoder.layer1.0.downsample.1.bias', 'backbone.encoder.layer1

In [None]:
a=torch.arange(5)
b=torch.arange(5)
c=torch.arange(5)

d = torch.vstack([a,b,c]).permute(1,0)
print(d)


In [None]:
from scipy.optimize import linear_sum_assignment
a = torch.tensor([ [0,1,2]  ,  [0,3,5] , [1,0,5] ]).to(torch.float32)
b = torch.tensor([ [0,1,2] , [1,0,5] ]).to(torch.float32)

cost = torch.cdist(b,a)
print(cost)
row , col = linear_sum_assignment(cost,)
print(row)
print(col)

In [None]:
x = torch.tensor([[0.7605, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7730, 0.5752, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7057, 0.5861, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.8386, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.8304, 0.7823, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7034, 0.5994, 0.5691, 0.5652, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6996, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.8305, 0.7819, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.8238, 0.7839, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
# Get the indices of non-zero elements
non_zero_indices = torch.nonzero(x)
print(non_zero_indices)
# Get the non-zero values
non_zero_values = x[non_zero_indices[:,0], non_zero_indices[:,1]]

unique = torch.unique(non_zero_indices[:,0] ,return_counts=True)
print("unique" , unique)
# Print the result
print(non_zero_values)
non_zero_values = torch.split(non_zero_values , tuple(unique[1]))
print("split non_zero_values" , non_zero_values)

def unpad_data( x :[Tensor] ) :
	non_zero_indices = torch.nonzero(x)
	print(non_zero_indices)
	# Get the non-zero values
	non_zero_values = x[non_zero_indices[:,0], non_zero_indices[:,1]]

	unique = torch.unique(non_zero_indices[:,0] ,return_counts=True)
	print("unique" , unique)
	# Print the result
	print(non_zero_values)
	non_zero_values = torch.split(non_zero_values , tuple(unique[1]))
	print("split non_zero_values" , non_zero_values)
	return non_zero_values

In [None]:
a = torch.tensor([[0.58 , 0.6] , [0.4] ] , )
b = torch.tensor([0.1 , 0.2] , )

c = a.repeat(2)
print(a.repeat(2))

In [None]:
transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
src = torch.rand((10, 32, 512))
tgt = torch.rand((20, 32, 512))
out = transformer_model(src, tgt)

print(out.shape)

In [None]:
import torch

print(torch.rand(1)[0])