In [50]:
#Import the required libraries

import torch
import torch.nn as nn
import numpy as np
import scipy.io 
import random
import math
import matplotlib.pyplot as plt
import torch.nn.functional as F
import pandas as pd
import os
from scipy.optimize import linear_sum_assignment
import scipy.stats as stats
import higher
os.environ['KMP_DUPLICATE_LIB_OK']='True' 
torch.cuda.is_available()

True

In [73]:
#############################################
#                Dataloader                 #
#############################################

#Read in the unlabeled multi-modal dataset.
#It is bound to `dataset_tensor` because that is the name the training cell indexes;
#the earlier name `data_tensor` is kept as an alias so any existing reference still works.
#NOTE: allow_pickle=True (and torch.load below) unpickle arbitrary Python objects,
#so only load files that come from a trusted source.
dataset_tensor = np.load('multi_modal_trainSet.npy',allow_pickle=True)
data_tensor = dataset_tensor


#Read in the labeled datasets.
#Later cells iterate these objects as (inputs, labels) batches and also read their
#`.dataset` attribute, so they are presumably pickled DataLoader objects -- TODO confirm.
train_dl_BS1 = torch.load('Datasets/train_dl_BS1_300.pt')
valid_dl_BS1 = torch.load('Datasets/valid_dl_BS1_100.pt')
testing_dl_BS1 = torch.load('Datasets/testing_dl_BS1.pt')

train_dl_BS2 = torch.load('Datasets/train_dl_BS2_300.pt')
valid_dl_BS2 = torch.load('Datasets/valid_dl_BS2_100.pt')
testing_dl_BS2 = torch.load('Datasets/testing_dl_BS2.pt')

In [52]:
#############################################
#                Tool Functions             #
#############################################

#Divide a bag of data into sections
def dividing(data_array,num_array):
    """Split `data_array` along its first axis into consecutive sections.

    Args:
        data_array: any object with `.shape[0]` and slice indexing (tensor or array).
        num_array: iterable of section lengths; they must add up to data_array.shape[0].

    Returns:
        A list with one slice of `data_array` per entry of `num_array`, in order.

    Raises:
        ValueError: if the section lengths do not add up to the number of rows.
            (Previously the string 'wrong!' was returned, which callers that
            index the result would silently misinterpret.)
    """
    total = sum(num_array)
    if data_array.shape[0] != total:
        raise ValueError(
            'dividing: section sizes sum to {} but data has {} rows'.format(total, data_array.shape[0])
        )
    result = []
    start = 0
    for count in num_array:
        result.append(data_array[start:start+count])
        start += count
    return result

#Metric
def metric_func(y_pred,y_true):
    """Mean Euclidean distance between predicted and true 2-D positions.

    Only column 0 (x) and column 1 (y) of each tensor are read.

    Returns:
        A 0-dim tensor: the sum of the per-row distances divided by y_true.shape[0].
    """
    n_rows = y_true.shape[0]

    # Per-axis error, ground truth minus prediction
    err_x = y_true[:,0] - y_pred[:,0]
    err_y = y_true[:,1] - y_pred[:,1]

    per_row_distance = torch.sqrt(err_x*err_x + err_y*err_y)
    return per_row_distance.sum()/n_rows


#E step with KM algorithm
def E_step(position_pred,position_image,stdDeviation_matrix):
    """Match predicted positions to candidate positions with a minimum-cost assignment.

    The cost of pairing prediction i with candidate j is
    (||pred_i - image_j|| / stdDeviation_matrix[i, j]) ** 2.

    Args:
        position_pred: (num_pred, 2) tensor of predicted positions.
        position_image: (num_image, 2) tensor of candidate positions.
        stdDeviation_matrix: tensor that divides the (num_pred, num_image) distance matrix.

    Returns:
        (row_ind, col_ind, cost_matrix): the two index arrays returned by
        scipy's linear_sum_assignment and the squared-cost matrix as a NumPy array.
    """
    n_pred, n_img = position_pred.shape[0], position_image.shape[0]

    # Lay both point sets out on a common (num_pred, num_image, 2) grid
    pred_grid = position_pred.unsqueeze(1).expand(n_pred, n_img, 2)
    image_grid = position_image.unsqueeze(0).expand(n_pred, n_img, 2)

    scaled_distance = (torch.norm(pred_grid - image_grid, dim=2)/stdDeviation_matrix).numpy()
    cost_matrix = scaled_distance*scaled_distance

    row_ind, col_ind = linear_sum_assignment(cost_matrix)
    return row_ind, col_ind, cost_matrix


#E step with Nearest neighbour
def Nearest(position_pred,position_image,stdDeviation_matrix):
    """Pair every predicted position with its lowest-cost candidate position.

    Uses the same cost as E_step, (||pred_i - image_j|| / stdDeviation_matrix[i, j]) ** 2,
    but each prediction picks independently, so one candidate may be chosen several times.

    Returns:
        (row_ind, col_ind): row_ind is 0..num_pred-1 and col_ind[i] is the index
        of the candidate chosen for prediction i (both torch tensors).
    """
    n_pred, n_img = position_pred.shape[0], position_image.shape[0]

    # Lay both point sets out on a common (num_pred, num_image, 2) grid
    pred_grid = position_pred.unsqueeze(1).expand(n_pred, n_img, 2)
    image_grid = position_image.unsqueeze(0).expand(n_pred, n_img, 2)

    scaled_distance = torch.norm(pred_grid - image_grid, dim=2)/stdDeviation_matrix
    cost_matrix = scaled_distance*scaled_distance

    row_ind = torch.tensor(range(n_pred))
    _, col_ind = torch.min(cost_matrix, dim=1)
    return row_ind, col_ind

In [53]:
#############################################
#                Neural Networks            #
#############################################

class Residual(nn.Module):
    """Residual block: two 3x3 convolutions plus a 1x1-convolution shortcut.

    Both the second convolution and the shortcut use stride (1, 2), so the
    last spatial dimension is down-sampled by 2 while the other one is kept.
    Every convolution is followed by its own BatchNorm2d; no activation is
    applied after the two branches are summed.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Main branch
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=(1,2))
        # Shortcut branch: matches channel count and stride of the main branch
        self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=(1,2))
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.bn3 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        main = F.relu(self.bn1(self.conv1(x)))
        main = F.relu(self.bn2(self.conv2(main)))
        shortcut = self.bn3(self.conv3(x))
        return main + shortcut
 

def resnet_block(in_channels, out_channels):
    """Return an nn.Sequential that wraps a single Residual(in_channels, out_channels)."""
    return nn.Sequential(Residual(in_channels, out_channels))

In [54]:
#############################################
#                Generate the NN            #
#############################################

torch.cuda.empty_cache()

#Whether using a pretrained model
pretraining = True


def _build_localization_model():
    """Build one freshly initialised network (CPU): two residual blocks + an MLP head with 2 outputs.

    Module names are kept exactly as in the original inline definition so that
    state-dict keys stay compatible with previously saved models.
    """
    model = nn.Sequential()
    model.add_module("resnet_block1", resnet_block(3, 32))
    model.add_module("resnet_block2", resnet_block(32, 64))
    model.add_module("flatten",nn.Flatten())
    model.add_module("Dropout",nn.Dropout(0.5))
    #64 channels x 16 x 13 feature map after the two blocks -- fixes the expected input size
    model.add_module("linear0",nn.Linear(64*16*13,32))
    model.add_module("linear1",nn.Linear(32,128))
    model.add_module("ReLU1",nn.ReLU())
    model.add_module("linear2",nn.Linear(128,2))
    return model


if pretraining:
    #Read in the pretrained model. You may need to modify the root path of the pretrained models.
    #NOTE: the loaded objects are used directly as models, so these files presumably hold whole
    #pickled modules -- the Residual class must already be defined in this kernel, and the files
    #must come from a trusted source because unpickling can execute arbitrary code.
    model1_BS1 = torch.load('model1_BS1_wellTrained_300data.pth')
    model1_BS2 = torch.load('model1_BS2_wellTrained_300data.pth')
    model1_BS1 = model1_BS1.cuda()
    model1_BS2 = model1_BS2.cuda()
else:
    #One independent network per base station, same architecture
    model1_BS1 = _build_localization_model().cuda()
    model1_BS2 = _build_localization_model().cuda()

In [55]:
#############################################
#                Training Setups            #
#############################################

#Adam with identical hyper-parameters for the two base-station networks
learning_rate = 0.001
optimizer_BS1 = torch.optim.Adam(
    model1_BS1.parameters(), lr=learning_rate, weight_decay=0.001)
optimizer_BS2 = torch.optim.Adam(
    model1_BS2.parameters(), lr=learning_rate, weight_decay=0.001)

#Learning-rate multiplier used by LambdaLR: 0.9 raised to the scheduler's epoch counter
lambda_BS1 = lambda epoch: 0.9 ** epoch
lambda_BS2 = lambda epoch: 0.9 ** epoch

scheduler_BS1 = torch.optim.lr_scheduler.LambdaLR(optimizer_BS1, lr_lambda=lambda_BS1)
scheduler_BS2 = torch.optim.lr_scheduler.LambdaLR(optimizer_BS2, lr_lambda=lambda_BS2)

#Mean-squared-error losses, moved to the GPU when one is available
criterion_BS1, criterion_BS2 = nn.MSELoss(), nn.MSELoss()
if torch.cuda.is_available():
    criterion_BS1, criterion_BS2 = criterion_BS1.cuda(), criterion_BS2.cuda()

In [56]:
#############################################
#         Initialize the weights            #
#############################################

def _collect_predictions(model, loader):
    """Run `model` over every (inputs, labels) batch of `loader` without gradients.

    Returns:
        (outputs, labels): all model outputs and all labels concatenated along
        dim 0, both on the CPU.
    """
    #The leading empty tensors reproduce the dtype promotion (and the empty-loader
    #result) of the incremental concatenation this helper replaces.
    output_chunks = [torch.tensor([])]
    label_chunks = [torch.tensor([])]
    with torch.no_grad():
        for batch_inputs, batch_labels in loader:
            output_chunks.append(model(batch_inputs.cuda()).cpu())
            label_chunks.append(batch_labels)
    #Concatenate once instead of re-copying the accumulated tensor on every batch
    return torch.cat(output_chunks,dim=0), torch.cat(label_chunks,dim=0)


#For each base station: fit a cubic polynomial that maps the norm of the validation
#label to the norm of the prediction error. The coefficients an_bs1 / an_bs2 are
#evaluated later with np.polyval to build the scaling matrix passed to E_step.
model1_BS1.eval()
val_output_all_bs1, val_labels_all_bs1 = _collect_predictions(model1_BS1, valid_dl_BS1)

distance_bs1 = torch.norm(val_labels_all_bs1,dim=1)
error_bs1 = torch.norm(val_output_all_bs1-val_labels_all_bs1,dim=1)

an_bs1 = np.polyfit(distance_bs1.numpy(),error_bs1.numpy(), 3)


model1_BS2.eval()
val_output_all_bs2, val_labels_all_bs2 = _collect_predictions(model1_BS2, valid_dl_BS2)

distance_bs2 = torch.norm(val_labels_all_bs2,dim=1)
error_bs2 = torch.norm(val_output_all_bs2-val_labels_all_bs2,dim=1)

an_bs2 = np.polyfit(distance_bs2.numpy(),error_bs2.numpy(), 3)

In [None]:
#############################################
#               Model Training              #
#############################################

#Number of bags drawn from the unlabeled dataset at every training step
batchSize = 24
Steps = 10000
#Only args['device'] is read in this cell (device of the per-sample meta weights)
args = {'bs':100, 'lr':1e-3, 'n_epochs':150, 'device':'cuda:0'}
test_result_bs1 = []
test_result_bs2 = []
#Coordinates of the two base stations. The network outputs are offsets relative to the
#corresponding base station: they are added to these coordinates before matching and
#the coordinates are subtracted again from the matched positions to form the labels.
bs1_coor = torch.tensor([80,14])
bs2_coor = torch.tensor([140,-14])

#Whether use the meta learning
using_meta_learning = True

#ReLU or proposed
whether_using_ReLU = False

for step in range(1,Steps+1):
    #Zero grad
    optimizer_BS1.zero_grad()
    optimizer_BS2.zero_grad()
    
    #training phase
    model1_BS1.train()
    model1_BS2.train()
    #random select several bags
    idx_step = torch.randperm(len(dataset_tensor))[:batchSize]
    
    #We need to initialize a tensor for integrating all the selected bags
    input_holder_all_bs1 = torch.tensor([])
    input_holder_all_bs2 = torch.tensor([])
    
    #number of elements in each bag
    number_of_instance_bs1 = []
    number_of_instance_bs2 = []
    
    #concat all the bags for batch normalization
    for item in idx_step:
        input_holder_all_bs1 = torch.cat([input_holder_all_bs1,dataset_tensor[item][0][0]],dim=0)
        input_holder_all_bs2 = torch.cat([input_holder_all_bs2,dataset_tensor[item][1][0]],dim=0)
        
        number_of_instance_bs1.append(dataset_tensor[item][0][0].shape[0])
        number_of_instance_bs2.append(dataset_tensor[item][1][0].shape[0])
    
    #go through the NN
    output_holder_all_bs1 = model1_BS1(input_holder_all_bs1.cuda())
    output_holder_all_bs2 = model1_BS2(input_holder_all_bs2.cuda())
    
    #re-divide the outputs into bags (detached copies shifted into the common coordinate frame)
    output_divide_bs1 = dividing(output_holder_all_bs1.cpu().detach()+bs1_coor.repeat(output_holder_all_bs1.shape[0],1),number_of_instance_bs1)
    output_divide_bs2 = dividing(output_holder_all_bs2.cpu().detach()+bs2_coor.repeat(output_holder_all_bs2.shape[0],1),number_of_instance_bs2)
    
    #Calculating the matching through KM algorithm
    with torch.no_grad():
        label_holder_bs1 = torch.tensor([])
        label_holder_bs2 = torch.tensor([])
        
        for j in range(len(output_divide_bs1)):
            position_pred_bs1 = output_divide_bs1[j]
            position_pred_bs2 = output_divide_bs2[j]
            position_pred_all = torch.cat([position_pred_bs1,position_pred_bs2],dim=0)
        
            #Candidate positions stored with this bag (element [1] of each base-station entry)
            position_image_all = torch.cat([dataset_tensor[idx_step[j]][0][1],dataset_tensor[idx_step[j]][1][1]],dim=0)
        
            distance_relativeto_bs1 = torch.norm(position_image_all - bs1_coor.repeat(position_image_all.shape[0],1),dim=1)
            distance_relativeto_bs2 = torch.norm(position_image_all - bs2_coor.repeat(position_image_all.shape[0],1),dim=1)
        
            #Expected error of each network at every candidate, from the fitted polynomials
            weighingVector_bs1 = torch.from_numpy(np.polyval(an_bs1,distance_relativeto_bs1.numpy()))
            weighingVector_bs2 = torch.from_numpy(np.polyval(an_bs2,distance_relativeto_bs2.numpy()))
        
            stdDeviation_matrix = torch.cat([weighingVector_bs1.repeat(position_pred_bs1.shape[0],1),weighingVector_bs2.repeat(position_pred_bs2.shape[0],1)],dim=0)
            row_ind, col_ind, _ = E_step(position_pred_all,position_image_all,stdDeviation_matrix)
        
            #Matched predictions take the candidate position as label;
            #unmatched predictions keep their own prediction.
            position_pred_all_copy = position_pred_all.clone()
            position_pred_all_copy[row_ind,:]=position_image_all[col_ind,:]
        
            label_partial_bs1 = position_pred_all_copy[:position_pred_bs1.shape[0],:]
            #Slice from the end of the BS1 rows: the former `[-n:]` form returned ALL rows when n == 0
            label_partial_bs2 = position_pred_all_copy[position_pred_bs1.shape[0]:,:]
            
            label_holder_bs1 = torch.cat([label_holder_bs1,label_partial_bs1-bs1_coor.repeat(label_partial_bs1.shape[0],1)],dim=0)
            label_holder_bs2 = torch.cat([label_holder_bs2,label_partial_bs2-bs2_coor.repeat(label_partial_bs2.shape[0],1)],dim=0)
    
    
    
    #Check it
    if output_holder_all_bs1.shape[0]!=label_holder_bs1.shape[0] or output_holder_all_bs2.shape[0]!=label_holder_bs2.shape[0]:
        print('Wrong!')
        break
    
    #Select some clean labeled data for regularization
    idx_BS1 = torch.randperm(train_dl_BS1.dataset[:][0].shape[0])[:min(label_holder_bs1.shape[0],train_dl_BS1.dataset[:][0].shape[0])]
    input_reg_bs1 = train_dl_BS1.dataset[:][0][idx_BS1,:,:,:]
    labels_reg_bs1 = train_dl_BS1.dataset[:][1][idx_BS1,:]
    idx_BS2 = torch.randperm(train_dl_BS2.dataset[:][0].shape[0])[:min(label_holder_bs2.shape[0],train_dl_BS2.dataset[:][0].shape[0])]
    input_reg_bs2 = train_dl_BS2.dataset[:][0][idx_BS2,:,:,:]
    labels_reg_bs2 = train_dl_BS2.dataset[:][1][idx_BS2,:]
    output_reg_bs1 = model1_BS1(input_reg_bs1.cuda())
    output_reg_bs2 = model1_BS2(input_reg_bs2.cuda())
          
    #Meta-learning based re-weighting
    if using_meta_learning is False:
        loss_bs1 = 0.5*criterion_BS1(output_holder_all_bs1,label_holder_bs1.cuda()) + 0.5*criterion_BS1(output_reg_bs1,labels_reg_bs1.cuda())
        loss_bs2 = 0.5*criterion_BS2(output_holder_all_bs2,label_holder_bs2.cuda()) + 0.5*criterion_BS2(output_reg_bs2,labels_reg_bs2.cuda())
        
    else:
        #Select some clean validation data
        idx_BS1_meta = torch.randperm(valid_dl_BS1.dataset[:][0].shape[0])[:min(label_holder_bs1.shape[0],valid_dl_BS1.dataset[:][0].shape[0])]
        input_meta_bs1 = valid_dl_BS1.dataset[:][0][idx_BS1_meta,:,:,:]
        labels_meta_bs1 = valid_dl_BS1.dataset[:][1][idx_BS1_meta,:]
        idx_BS2_meta = torch.randperm(valid_dl_BS2.dataset[:][0].shape[0])[:min(label_holder_bs2.shape[0],valid_dl_BS2.dataset[:][0].shape[0])]
        input_meta_bs2 = valid_dl_BS2.dataset[:][0][idx_BS2_meta,:,:,:]
        labels_meta_bs2 = valid_dl_BS2.dataset[:][1][idx_BS2_meta,:]
        
        #Meta-learning For BS1
        #One differentiable optimizer step is taken on a copy of the model with per-sample
        #weights eps (all zero); the gradient of the clean validation loss w.r.t. eps is
        #then used to derive the per-sample weights.
        with higher.innerloop_ctx(model1_BS1, optimizer_BS1) as (meta_model_bs1, meta_opt_bs1):
            meta_train_outputs_bs1 = meta_model_bs1(input_holder_all_bs1.cuda())
            meta_train_loss_bs1 = torch.square(torch.norm(meta_train_outputs_bs1-label_holder_bs1.cuda(),dim=1))
            eps_bs1 = torch.zeros(meta_train_loss_bs1.size(), requires_grad=True, device=args['device'])
            meta_train_loss_bs1 = torch.sum(eps_bs1 * meta_train_loss_bs1)
            meta_opt_bs1.step(meta_train_loss_bs1)

            meta_val_outputs_bs1 = meta_model_bs1(input_meta_bs1.cuda())
            meta_val_loss_bs1 = torch.square(torch.norm(meta_val_outputs_bs1-labels_meta_bs1.cuda(),dim=1)).mean()
            eps_grads_bs1 = torch.autograd.grad(meta_val_loss_bs1, eps_bs1)[0].detach()
        

        if whether_using_ReLU is True:
            #Keep only samples whose up-weighting lowers the validation loss, then normalise
            w_tilde_bs1 = torch.clamp(-eps_grads_bs1, min=0)
            num_bs1_nonzero = torch.count_nonzero(w_tilde_bs1)
            l1_norm_bs1 = torch.sum(w_tilde_bs1)
            if l1_norm_bs1 != 0:
                w_bs1 = (w_tilde_bs1.shape[0]/num_bs1_nonzero) * w_tilde_bs1 / l1_norm_bs1
            else:
                w_bs1 = w_tilde_bs1
                
        elif whether_using_ReLU is False:
            #Scale the negated gradients into [-1, 1] and shift them so the weights lie in [0, 2].
            #If every gradient is exactly zero the scaling would be 0/0 = NaN, so fall back
            #to uniform weights of 1 in that case.
            max_abs_grad_bs1 = torch.max(torch.abs(eps_grads_bs1))
            if max_abs_grad_bs1 == 0:
                w_tilde_bs1 = torch.zeros_like(eps_grads_bs1)
            else:
                w_tilde_bs1 = -eps_grads_bs1/max_abs_grad_bs1
            w_bs1 = torch.ones(w_tilde_bs1.size()).cuda() + w_tilde_bs1
        

        #Meta learning For BS2 (same procedure as for BS1)
        with higher.innerloop_ctx(model1_BS2, optimizer_BS2) as (meta_model_bs2, meta_opt_bs2):
            meta_train_outputs_bs2 = meta_model_bs2(input_holder_all_bs2.cuda())
            meta_train_loss_bs2 = torch.square(torch.norm(meta_train_outputs_bs2-label_holder_bs2.cuda(),dim=1))
            eps_bs2 = torch.zeros(meta_train_loss_bs2.size(), requires_grad=True, device=args['device'])
            meta_train_loss_bs2 = torch.sum(eps_bs2 * meta_train_loss_bs2)
            meta_opt_bs2.step(meta_train_loss_bs2)

            meta_val_outputs_bs2 = meta_model_bs2(input_meta_bs2.cuda())
            #dim=1 gives one squared error per sample, matching BS1; without it the norm was
            #taken over the whole tensor and .mean() averaged a single scalar.
            meta_val_loss_bs2 = torch.square(torch.norm(meta_val_outputs_bs2-labels_meta_bs2.cuda(),dim=1)).mean()
            eps_grads_bs2 = torch.autograd.grad(meta_val_loss_bs2, eps_bs2)[0].detach()
        
        if whether_using_ReLU is True:
            w_tilde_bs2 = torch.clamp(-eps_grads_bs2, min=0)
            num_bs2_nonzero = torch.count_nonzero(w_tilde_bs2)
            l1_norm_bs2 = torch.sum(w_tilde_bs2)
            if l1_norm_bs2 != 0:
                w_bs2 = (w_tilde_bs2.shape[0]/num_bs2_nonzero) * w_tilde_bs2 / l1_norm_bs2
            else:
                w_bs2 = w_tilde_bs2
            
        elif whether_using_ReLU is False:
            #Same zero-gradient guard as for BS1
            max_abs_grad_bs2 = torch.max(torch.abs(eps_grads_bs2))
            if max_abs_grad_bs2 == 0:
                w_tilde_bs2 = torch.zeros_like(eps_grads_bs2)
            else:
                w_tilde_bs2 = -eps_grads_bs2/max_abs_grad_bs2
            w_bs2 = torch.ones(w_tilde_bs2.size()).cuda() + w_tilde_bs2
        

        
        #Per-sample weighted loss on the matched (weak) labels
        loss_bs1_weakly = torch.square(torch.norm(output_holder_all_bs1-label_holder_bs1.cuda(),dim=1))/2
        loss_bs1_weakly = torch.mean(w_bs1 * loss_bs1_weakly)
        
        loss_bs2_weakly = torch.square(torch.norm(output_holder_all_bs2-label_holder_bs2.cuda(),dim=1))/2
        loss_bs2_weakly = torch.mean(w_bs2 * loss_bs2_weakly)
        
        loss_bs1 = 0.5*loss_bs1_weakly + 0.5 * criterion_BS1(output_reg_bs1,labels_reg_bs1.cuda())
        loss_bs2 = 0.5*loss_bs2_weakly + 0.5 * criterion_BS2(output_reg_bs2,labels_reg_bs2.cuda())
    
    #Optimizing the NNs
    loss_bs1.backward()
    loss_bs2.backward()
    
    optimizer_BS1.step()
    optimizer_BS2.step()
    
    torch.cuda.empty_cache()
    
    #Lr decay: from step 5000 on, multiply the learning rate by 0.9 every 200 steps
    if step%200==0 and step>=5000:
        scheduler_BS1.step()
        scheduler_BS2.step()
       
    
    #Evaluate the performance on the testing datasets.
    #Note that a testing dataset is normally not used until the model is well-trained.
    #Here the validation sets are already consumed to compute the per-position weights,
    #so the testing datasets are used to monitor performance during training.
    if step%100==0:
        #Evaluation phase
        model1_BS1.eval()
        model1_BS2.eval()

        test_metric_sum_bs1 = 0.0
        test_step_bs1 = 1
    
        test_output_all_bs1 = torch.tensor([])
        test_labels_all_bs1 = torch.tensor([])
        for test_step_bs1, (test_inputs_bs1,test_labels_bs1) in enumerate(testing_dl_BS1, 1):
            with torch.no_grad():
                test_outputs_bs1 = model1_BS1(test_inputs_bs1.cuda())
                test_metric_bs1 = metric_func(test_outputs_bs1,test_labels_bs1.cuda())
                test_output_all_bs1 = torch.cat([test_output_all_bs1,test_outputs_bs1.cpu()],dim=0)
                test_labels_all_bs1 = torch.cat([test_labels_all_bs1,test_labels_bs1],dim=0)

            test_metric_sum_bs1 += test_metric_bs1


        test_metric_sum_bs2 = 0.0
        test_step_bs2 = 1
        test_output_all_bs2 = torch.tensor([])
        test_labels_all_bs2 = torch.tensor([])
        for test_step_bs2, (test_inputs_bs2,test_labels_bs2) in enumerate(testing_dl_BS2, 1):
            with torch.no_grad():
                test_outputs_bs2 = model1_BS2(test_inputs_bs2.cuda())
                test_metric_bs2 = metric_func(test_outputs_bs2,test_labels_bs2.cuda())
                test_output_all_bs2 = torch.cat([test_output_all_bs2,test_outputs_bs2.cpu()],dim=0)
                test_labels_all_bs2 = torch.cat([test_labels_all_bs2,test_labels_bs2],dim=0)
            test_metric_sum_bs2 += test_metric_bs2
            
        #Average of the per-batch mean distance errors
        info = (step,test_metric_sum_bs1/test_step_bs1,test_metric_sum_bs2/test_step_bs2)

        print(("\nSTEP = %d,test_metric_bs1 = %.3f,test_metric_bs2 = %.3f\n") 
              %info)
        
        test_result_bs1.append((test_metric_sum_bs1/test_step_bs1).item())
        test_result_bs2.append((test_metric_sum_bs2/test_step_bs2).item())
    else:
        print('step {},loss_bs1 {},loss_bs2 {}'.format(step, loss_bs1.item(), loss_bs2.item()))
        
    
    if step%50==0:
        #Re-calculate the sigma function with polyfit
        model1_BS1.eval()
        model1_BS2.eval()
    
        val_output_all_bs1 = torch.tensor([])
        val_labels_all_bs1 = torch.tensor([])
        for val_step_bs1, (val_inputs_bs1,val_labels_bs1) in enumerate(valid_dl_BS1, 1):
            with torch.no_grad():
                val_outputs_bs1 = model1_BS1(val_inputs_bs1.cuda())
                val_output_all_bs1 = torch.cat([val_output_all_bs1,val_outputs_bs1.cpu()],dim=0)
                val_labels_all_bs1 = torch.cat([val_labels_all_bs1,val_labels_bs1],dim=0)
            
        distance_bs1 = torch.norm(val_labels_all_bs1,dim=1)
        error_bs1 = torch.norm(val_output_all_bs1-val_labels_all_bs1,dim=1)
        an_bs1 = np.polyfit(distance_bs1.numpy(),error_bs1.numpy(), 3)
        
        val_output_all_bs2 = torch.tensor([])
        val_labels_all_bs2 = torch.tensor([])
        for val_step_bs2, (val_inputs_bs2,val_labels_bs2) in enumerate(valid_dl_BS2, 1):
            with torch.no_grad():
                val_outputs_bs2 = model1_BS2(val_inputs_bs2.cuda())
                val_output_all_bs2 = torch.cat([val_output_all_bs2,val_outputs_bs2.cpu()],dim=0)
                val_labels_all_bs2 = torch.cat([val_labels_all_bs2,val_labels_bs2],dim=0)
            
        distance_bs2 = torch.norm(val_labels_all_bs2,dim=1)
        error_bs2 = torch.norm(val_output_all_bs2-val_labels_all_bs2,dim=1)
        an_bs2 = np.polyfit(distance_bs2.numpy(),error_bs2.numpy(), 3)
        
    torch.cuda.empty_cache()  
    

In [189]:
#############################################
#               Model Testing               #
#############################################

#Read in multi-modal testing dataset
#NOTE: allow_pickle=True unpickles arbitrary objects -- only load trusted files.
dataset_tensor_testing = np.load('multi_modal_testSet.npy',allow_pickle=True)

#Coordinates of BS
bs1_coor = torch.tensor([80,14])
bs2_coor = torch.tensor([140,-14])

#Initialize sum of metrics
test_metric_sum_bs1_withImage_collective = 0.0
test_metric_sum_bs1_noImage = 0.0
test_metric_sum_bs1_withImage_collective_nearest = 0.0

test_metric_sum_bs2_withImage_collective = 0.0
test_metric_sum_bs2_noImage = 0.0
test_metric_sum_bs2_withImage_collective_nearest = 0.0
test_step = 1

#Switch the mode
model1_BS1.eval()
model1_BS2.eval()

#Initialize tensors for recording the testing results
test_labels_bs1_holder = torch.tensor([])
test_labels_bs2_holder = torch.tensor([])
test_output_bs1_holder = torch.tensor([])
test_output_bs2_holder = torch.tensor([])
test_calibrated_bs1_holder = torch.tensor([])
test_calibrated_bs2_holder = torch.tensor([])
test_calibrated_bs1_holder_collective_nearest = torch.tensor([])
test_calibrated_bs2_holder_collective_nearest = torch.tensor([])

sum_num_of_chosen_bs1 = 0
sum_num_of_chosen_bs2 = 0

#Each test entry holds, per base station: (network inputs, ground-truth positions, candidate positions)
for test_step, ((test_inputs_bs1,test_labels_bs1,test_labels_bs1_chosen),(test_inputs_bs2,test_labels_bs2,test_labels_bs2_chosen)) in enumerate(dataset_tensor_testing, 1):
    
    with torch.no_grad():
        #Direct network prediction, evaluated relative to each base station
        test_outputs_bs1 = model1_BS1(test_inputs_bs1.cuda())
        test_labels_bs1_relative = test_labels_bs1 - bs1_coor.repeat(test_labels_bs1.shape[0],1)
        test_metric_bs1_noImage = metric_func(test_outputs_bs1,test_labels_bs1_relative.cuda())
        test_metric_sum_bs1_noImage += test_metric_bs1_noImage
        test_labels_bs1_holder = torch.cat([test_labels_bs1_holder,test_labels_bs1_relative],dim=0)
        test_output_bs1_holder = torch.cat([test_output_bs1_holder,test_outputs_bs1.detach().cpu()],dim=0)
            
        test_outputs_bs2 = model1_BS2(test_inputs_bs2.cuda())
        test_labels_bs2_relative = test_labels_bs2 - bs2_coor.repeat(test_labels_bs2.shape[0],1)
        test_metric_bs2_noImage = metric_func(test_outputs_bs2,test_labels_bs2_relative.cuda())
        test_metric_sum_bs2_noImage += test_metric_bs2_noImage
        test_labels_bs2_holder = torch.cat([test_labels_bs2_holder,test_labels_bs2_relative],dim=0)
        test_output_bs2_holder = torch.cat([test_output_bs2_holder,test_outputs_bs2.detach().cpu()],dim=0)
            
    
            
        sum_num_of_chosen_bs1 = sum_num_of_chosen_bs1 + len(test_labels_bs1_chosen)
        sum_num_of_chosen_bs2 = sum_num_of_chosen_bs2 + len(test_labels_bs2_chosen)

            
        #Pool the candidate positions of both base stations and derive the per-candidate scaling
        test_labels_chosen = torch.cat([test_labels_bs1_chosen,test_labels_bs2_chosen],dim=0)
        distance_relativeto_bs1 = torch.norm(test_labels_chosen - bs1_coor.repeat(test_labels_chosen.shape[0],1),dim=1)
        distance_relativeto_bs2 = torch.norm(test_labels_chosen - bs2_coor.repeat(test_labels_chosen.shape[0],1),dim=1)
            
        weighingVector_bs1 = torch.from_numpy(np.polyval(an_bs1,distance_relativeto_bs1.numpy()))
        weighingVector_bs2 = torch.from_numpy(np.polyval(an_bs2,distance_relativeto_bs2.numpy()))
        
        weighingMatrix_bs1 = weighingVector_bs1.repeat(test_inputs_bs1.shape[0],1)
        weighingMatrix_bs2 = weighingVector_bs2.repeat(test_inputs_bs2.shape[0],1)
        
        
        
        #Solve the joint (collective) matching over the predictions of both base stations
        test_outputs_all = torch.cat([test_outputs_bs1.cpu()+bs1_coor.repeat(test_outputs_bs1.shape[0],1),test_outputs_bs2.cpu()+bs2_coor.repeat(test_outputs_bs2.shape[0],1)],dim=0)
        weighingMatrix_all = torch.cat([weighingMatrix_bs1,weighingMatrix_bs2],dim=0)
        row_ind_all, col_ind_all,root_matrix = E_step(test_outputs_all,test_labels_chosen,weighingMatrix_all)
        
        #Matched predictions are replaced by their candidate position; unmatched ones are kept
        test_outputs_all_calibrate = test_outputs_all.clone()
        test_outputs_all_calibrate[row_ind_all,:] = test_labels_chosen[col_ind_all,:]
        
        test_outputs_bs1_calibrate_collective = test_outputs_all_calibrate[:test_outputs_bs1.shape[0],:] - bs1_coor.repeat(test_outputs_bs1.shape[0],1)
        #Slice from the end of the BS1 rows: the former `[-n:]` form returned ALL rows when n == 0
        test_outputs_bs2_calibrate_collective = test_outputs_all_calibrate[test_outputs_bs1.shape[0]:,:] - bs2_coor.repeat(test_outputs_bs2.shape[0],1)
        
        test_metric_bs1_withImage_collective = metric_func(test_labels_bs1_relative,test_outputs_bs1_calibrate_collective)
        test_metric_sum_bs1_withImage_collective += test_metric_bs1_withImage_collective
        test_calibrated_bs1_holder = torch.cat([test_calibrated_bs1_holder,test_outputs_bs1_calibrate_collective],dim=0)
        
        test_metric_bs2_withImage_collective = metric_func(test_labels_bs2_relative,test_outputs_bs2_calibrate_collective)
        test_metric_sum_bs2_withImage_collective += test_metric_bs2_withImage_collective
        test_calibrated_bs2_holder = torch.cat([test_calibrated_bs2_holder,test_outputs_bs2_calibrate_collective],dim=0)
        
        #Baseline: every prediction independently takes its lowest-cost candidate
        row_ind_all_nearest, col_ind_all_nearest = Nearest(test_outputs_all,test_labels_chosen,weighingMatrix_all)
        test_outputs_all_calibrate_nearest = test_outputs_all.clone()
        test_outputs_all_calibrate_nearest[row_ind_all_nearest,:] = test_labels_chosen[col_ind_all_nearest,:]
        test_outputs_bs1_calibrate_collective_nearest = test_outputs_all_calibrate_nearest[:test_outputs_bs1.shape[0],:] - bs1_coor.repeat(test_outputs_bs1.shape[0],1)
        test_outputs_bs2_calibrate_collective_nearest = test_outputs_all_calibrate_nearest[test_outputs_bs1.shape[0]:,:] - bs2_coor.repeat(test_outputs_bs2.shape[0],1)
        test_metric_bs1_withImage_collective_nearest = metric_func(test_labels_bs1_relative,test_outputs_bs1_calibrate_collective_nearest)
        test_metric_sum_bs1_withImage_collective_nearest += test_metric_bs1_withImage_collective_nearest
        test_metric_bs2_withImage_collective_nearest = metric_func(test_labels_bs2_relative,test_outputs_bs2_calibrate_collective_nearest)
        test_metric_sum_bs2_withImage_collective_nearest += test_metric_bs2_withImage_collective_nearest
        
        test_calibrated_bs1_holder_collective_nearest = torch.cat([test_calibrated_bs1_holder_collective_nearest,test_outputs_bs1_calibrate_collective_nearest],dim=0)
        test_calibrated_bs2_holder_collective_nearest = torch.cat([test_calibrated_bs2_holder_collective_nearest,test_outputs_bs2_calibrate_collective_nearest],dim=0)
        
    
#Calculating all the metrics (per-sample Euclidean errors over the whole test set)
distance_limit = 5
test_error_Directly_bs1 = torch.norm(test_output_bs1_holder - test_labels_bs1_holder,dim=1)
test_error_KM_bs1 = torch.norm(test_calibrated_bs1_holder - test_labels_bs1_holder,dim=1)
test_error_Nearest_bs1 = torch.norm(test_calibrated_bs1_holder_collective_nearest - test_labels_bs1_holder,dim=1)
#Wherever the nearest-candidate error reaches distance_limit, use the direct-prediction error instead.
#NOTE(review): this fallback is decided from the error against the ground truth -- confirm that is intended.
test_error_Nearest_bs1_calibrate = test_error_Nearest_bs1.clone()
use_direct_bs1 = test_error_Nearest_bs1_calibrate>=distance_limit
test_error_Nearest_bs1_calibrate[use_direct_bs1] = test_error_Directly_bs1[use_direct_bs1].to(test_error_Nearest_bs1_calibrate.dtype)

test_error_Directly_bs2 = torch.norm(test_output_bs2_holder - test_labels_bs2_holder,dim=1)
test_error_KM_bs2 = torch.norm(test_calibrated_bs2_holder - test_labels_bs2_holder,dim=1)
test_error_Nearest_bs2 = torch.norm(test_calibrated_bs2_holder_collective_nearest - test_labels_bs2_holder,dim=1)
test_error_Nearest_bs2_calibrate = test_error_Nearest_bs2.clone()
use_direct_bs2 = test_error_Nearest_bs2_calibrate>=distance_limit
test_error_Nearest_bs2_calibrate[use_direct_bs2] = test_error_Directly_bs2[use_direct_bs2].to(test_error_Nearest_bs2_calibrate.dtype)
    
#Metrics of BS1
test_metric_mean_bs1_withImage_collective = torch.mean(test_error_KM_bs1)
test_metric_mean_bs1_noImage = torch.mean(test_error_Directly_bs1)
test_metric_mean_bs1_withImage_collective_nearest = torch.mean(test_error_Nearest_bs1)
test_metric_mean_bs1_withImage_collective_nearest_calibrate = torch.mean(test_error_Nearest_bs1_calibrate)

#Metrics of BS2
test_metric_mean_bs2_withImage_collective = torch.mean(test_error_KM_bs2)
test_metric_mean_bs2_noImage = torch.mean(test_error_Directly_bs2)
test_metric_mean_bs2_withImage_collective_nearest = torch.mean(test_error_Nearest_bs2)
test_metric_mean_bs2_withImage_collective_nearest_calibrate = torch.mean(test_error_Nearest_bs2_calibrate)