# Imports

In [1]:
import transformers
from transformers import CLIPConfig, CLIPModel, CLIPProcessor, CLIPImageProcessor, CLIPTokenizerFast
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
import numpy as np
import random
import math
import scipy.io as sio
import nibabel as nib
from pathlib import Path
from gensim.models import Word2Vec
import re

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


# Load word and fMRI data

### Load 3d fMRI data

In [3]:
NUM_SUBJS = 8
subjects_fmri = [] #stores all 8 subject fmri np arrays

fMRI_folder = Path('./doi_10.5061_dryad.gt413__v1')
assert fMRI_folder.exists(), f"Foldder: {fMRI_folder} does not exist."

for subj_id in range(8):
#     fmri_file_name = str(subj_id) + '_masked_2d.npy'
#     fmri = np.load(fMRI_folder / fmri_file_name)  
    fmri_file_name = str(subj_id) + '_smooth_nifti_4d.nii'
    fmri = nib.load(fMRI_folder / fmri_file_name)
    fmri = np.array(fmri.dataobj)
    assert isinstance(fmri, np.ndarray), f"Imported fmri_scan for subject {subj_id} is not of type numpy.ndarray"
    assert(fmri.ndim) == 4, f"Imported fmri_scan for subject {subj_id} is not 4 dimensional"
    subjects_fmri.append(fmri)

# Load words

In [4]:
feature_matrix = np.zeros((5176,195)) #stores the feature vectors as a row for each word
feature_names = [] #stores the names of all features in order
feature_types = {} #stores the types of features and all the names of the features for each type

features = sio.loadmat(fMRI_folder / 'story_features.mat')
feature_count = 0
for feature_type in features['features'][0]:
    feature_types[feature_type[0][0]] = []
    if isinstance(feature_type[1][0], str):
        feature_types[feature_type[0][0]].append(feature_type[1][0])
        feature_names.append(feature_type[1][0])
    else:
        for feature in feature_type[1][0]:
            feature_types[feature_type[0][0]].append(feature[0])
            feature_names.append(feature[0])
    feature_matrix[:, feature_count:feature_count+feature_type[2].shape[1]] = feature_type[2] #adds the (5176xN) feature values to the feature matrix for the current feature group
    feature_count += feature_type[2].shape[1]

In [5]:
words_info = [] #stores tuples of (word, time, features) sorted by time appeared

mat_file = fMRI_folder / 'subject_1.mat' #only looks at the first subject file, somewhere it said all the timings were the same so this should be safe
mat_contents = sio.loadmat(mat_file)
for count, row in enumerate(mat_contents['words'][0]):
    word_value = row[0][0][0][0]
    time = row[1][0][0]
    word_tuple = (word_value, time, feature_matrix[count,:])
    words_info.append(word_tuple)

In [6]:
chapter_nine_text = ""
for row in mat_contents['words'][0]:
    chapter_nine_text += row[0][0][0][0] + " "
print(chapter_nine_text)

Harry had never believed he would meet a boy he hated more than Dudley, but that was before he met Draco Malfoy. Still, first-year Gryffindors only had Potions with the Slytherins, so they didn't have to put up with Malfoy much. Or at least, they didn't until they spotted a notice pinned up in the Gryffindor common room that made them all groan. Flying lessons would be starting on Thursday -- and Gryffindor and Slytherin would be learning together. + "Typical," said Harry darkly. "Just what I always wanted. To make a fool of myself on a broomstick in front of Malfoy." + He had been looking forward to learning to fly more than anything else. "You don't know that you'll make a fool of yourself," said Ron reasonably. "Anyway, I know Malfoy's always going on about how good he is at Quidditch, but I bet that's all talk." + Malfoy certainly did talk about flying a lot. He complained loudly about first years never getting on the House Quidditch teams and told long, boastful stories that alway

### Align fMRI scans with sets of 4 words

In [7]:
subjects_samples = [[] for i in range(NUM_SUBJS)] #stores lists of all the samples for each subject

word_count = 0
while word_count < len(words_info) - 8:
    #gets the 4 input words, and the 4 consecutive words while verifying they were read in sequence
    scan_words = []
    start_time = words_info[word_count][1]
    in_sequence = True #tracks if the words are in sequence or not
    for i in range(8):
        word_info = words_info[word_count + i]
        if word_info[1] != start_time + 0.5*i:
            #if some of the words are not in sequence, skip forward 1 word after innter loop
            in_sequence = False
        scan_words.append(word_info[0])
    if not in_sequence:
        word_count +=1
        continue
    fmri_time = start_time + 2 #effect of reading words is assumed to start 2 seconds after and end 8 seconds after
    fmri_index = fmri_time//2 #since a scan happens every two seconds, the index is the time divided by 2
    if not isinstance(fmri_index, np.int32):
        #if the first word is not aligned with the fmri scan (i.e. its not the first word in a TR)
        word_count += 1
        continue
    for count, subject in enumerate(subjects_fmri):
        #adds tuple of (fmri_scan, four words)
        subjects_samples[count].append((subject[:,:,:,fmri_index+2], scan_words[0:4]))
    print("Created sample:")
    print("\tScan time:", str(start_time))
    print("\tInput words:", str(scan_words[0:4]))
    #if successful, skip forward to the next set of 4 words
    word_count += 4

print("Total number of samples:", str(len(subjects_samples[0])))

Created sample:
	Scan time: 20
	Input words: ['Harry', 'had', 'never', 'believed']
Created sample:
	Scan time: 22
	Input words: ['he', 'would', 'meet', 'a']
Created sample:
	Scan time: 24
	Input words: ['boy', 'he', 'hated', 'more']
Created sample:
	Scan time: 26
	Input words: ['than', 'Dudley,', 'but', 'that']
Created sample:
	Scan time: 28
	Input words: ['was', 'before', 'he', 'met']
Created sample:
	Scan time: 30
	Input words: ['Draco', 'Malfoy.', 'Still,', 'first-year']
Created sample:
	Scan time: 32
	Input words: ['Gryffindors', 'only', 'had', 'Potions']
Created sample:
	Scan time: 34
	Input words: ['with', 'the', 'Slytherins,', 'so']
Created sample:
	Scan time: 36
	Input words: ['they', "didn't", 'have', 'to']
Created sample:
	Scan time: 38
	Input words: ['put', 'up', 'with', 'Malfoy']
Created sample:
	Scan time: 40
	Input words: ['much.', 'Or', 'at', 'least,']
Created sample:
	Scan time: 42
	Input words: ['they', "didn't", 'until', 'they']
Created sample:
	Scan time: 44
	Input w

	Input words: ['contact.', "What's", 'the', 'matter?']
Created sample:
	Scan time: 1462
	Input words: ['Never', 'heard', 'of', 'a']
Created sample:
	Scan time: 1464
	Input words: ["wizard's", 'duel', 'before,', 'I']
Created sample:
	Scan time: 1466
	Input words: ['suppose?"', '+', '"Of', 'course']
Created sample:
	Scan time: 1468
	Input words: ['he', 'has,"', 'said', 'Ron,']
Created sample:
	Scan time: 1470
	Input words: ['wheeling', 'around.', '"I\'m', 'his']
Created sample:
	Scan time: 1472
	Input words: ['second,', "who's", 'yours?"', '+']
Created sample:
	Scan time: 1474
	Input words: ['Malfoy', 'looked', 'at', 'Crabbe']
Created sample:
	Scan time: 1476
	Input words: ['and', 'Goyle,', 'sizing', 'them']
Created sample:
	Scan time: 1478
	Input words: ['up.', '+', '"Crabbe,"', 'he']
Created sample:
	Scan time: 1480
	Input words: ['said.', '"Midnight', 'all', 'right?']
Created sample:
	Scan time: 1482
	Input words: ["We'll", 'meet', 'you', 'in']
Created sample:
	Scan time: 1484
	Input 

# Text encoding

In [22]:
#https://radimrehurek.com/gensim/models/word2vec.html

In [8]:
with open("./J. K. Rowling - Harry Potter 1 - Sorcerer's Stone.txt") as file:
    data = file.read().replace('\n',' ')

sentence_tokens = []
#split by all periods not preceded by Mr or Mrs
for sentence in re.split("(?<!Mr)(?<!Mrs)\.($| )", data + " " + chapter_nine_text):
    sentence_tokens.append(sentence.split(" "))
print(sentence_tokens)



In [9]:
word_vec_encoder = Word2Vec(sentences=sentence_tokens, vector_size=100, window=4, min_count=0)

In [98]:
vector = model.wv['Harry']
print(vector)
sims = model.wv.most_similar('Harry', topn=10)
print(sims)

[-0.1592762   1.0358157   0.3998281   0.65896773 -0.14341635 -1.1957458
  0.5335763   1.7380077  -0.766566   -1.0356336  -0.07566728 -1.1751062
  0.62940866  0.7446685   0.7535447  -0.9587012   0.8585927  -0.89001256
  0.47481298 -1.7392517   0.6873861   0.32204965  0.9817534   0.1653276
 -0.32936937  0.02548517 -0.7163936  -0.16398834 -0.2911528   0.2216895
  0.869124   -0.5320245   0.6003812  -0.835552   -0.5567736   0.8724841
 -0.03029839 -0.2978085  -0.02549259 -1.0850229   0.36631638 -0.8160902
 -0.8166265  -0.4492721   0.4768799  -0.79738665 -0.5776063  -0.12688905
  0.62858564  0.05322775  0.2809282  -0.36204645 -0.38997352 -0.03191327
 -0.06836174  0.18175007  0.15393417 -0.2835372  -0.5398434   0.4702249
 -0.497334    0.07764219  0.72429174  0.21869192 -1.058616    0.64307624
  0.26159742  0.47916225 -0.8020325   0.76346177 -0.2581009  -0.115255
  0.920842    0.13798098  0.83246446  0.07702633  0.33448848  0.49757266
 -0.69999015 -0.22959568 -0.7526406  -0.38327488 -0.01637377

In [99]:
vec = model.wv['+']
print(vec)
sims = model.wv.most_similar('+', topn=10)
print(sims)

[-0.11667868  0.8429349   0.32117662  0.53995454 -0.12214363 -0.95944095
  0.4410393   1.4363979  -0.6178511  -0.8422581  -0.05303952 -0.9375535
  0.45371264  0.6066769   0.5892294  -0.7796291   0.6357723  -0.73351073
  0.36776012 -1.4180871   0.5679081   0.28288183  0.79845035  0.17510816
 -0.24930765  0.00582835 -0.593178   -0.12943596 -0.23465817  0.16131814
  0.6633559  -0.41084656  0.49868703 -0.6812801  -0.44114894  0.7041934
 -0.01215455 -0.26742408 -0.04254748 -0.90935147  0.25611266 -0.6575219
 -0.64853936 -0.364395    0.40166733 -0.6193323  -0.47563508 -0.09127057
  0.49026     0.07301917  0.25201496 -0.27739495 -0.35510117 -0.04099831
 -0.05254434  0.11012886  0.13631187 -0.2566088  -0.46786755  0.3700717
 -0.42504042  0.0688227   0.58254904  0.11905532 -0.84944856  0.5287194
  0.251611    0.40147305 -0.64631826  0.6357761  -0.20978577 -0.06242826
  0.7103947   0.09084885  0.6857637   0.07578729  0.3179003   0.3954695
 -0.56217504 -0.18856771 -0.59203774 -0.28812164 -0.01278

In [7]:
tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'bert-base-cased')
text_encoder = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-cased').to(device)

Using cache found in C:\Users\Portable/.cache\torch\hub\huggingface_pytorch-transformers_main
Using cache found in C:\Users\Portable/.cache\torch\hub\huggingface_pytorch-transformers_main


In [8]:

test_text = "Draco Malfoy. Still, first-year"

# Tokenized input with special tokens around it (for BERT: [CLS] at the beginning and [SEP] at the end)
indexed_tokens = tokenizer.encode(test_text, add_special_tokens=False)
print(indexed_tokens)
print(len(indexed_tokens))

# Define sentence A and B indices associated to 1st and 2nd sentences (see paper)
#segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]

# Convert inputs to PyTorch tensors
tokens_tensor = torch.tensor([indexed_tokens]).to(device)
print(tokens_tensor)
print(tokens_tensor.shape)

with torch.no_grad():
    print(text_encoder(tokens_tensor)['last_hidden_state'])
    print(text_encoder(tokens_tensor)['last_hidden_state'].shape)

[1987, 17312, 18880, 14467, 1183, 119, 4209, 117, 1148, 118, 1214]
11
tensor([[ 1987, 17312, 18880, 14467,  1183,   119,  4209,   117,  1148,   118,
          1214]], device='cuda:0')
torch.Size([1, 11])
tensor([[[ 0.2427,  0.2542,  0.3174,  ..., -0.5812,  0.2457, -0.0402],
         [-0.0625,  0.1290,  0.2033,  ..., -0.5730,  0.4792,  0.1508],
         [ 0.1854, -0.0772,  0.3786,  ..., -0.3806,  0.3886, -0.0047],
         ...,
         [-0.1498, -0.2543, -0.0726,  ..., -0.7107, -0.1733,  0.4407],
         [ 0.4777, -0.1650,  0.0273,  ..., -0.4016,  0.0094, -0.1742],
         [ 0.7020,  0.1362,  0.3621,  ..., -0.2534, -0.1503, -0.0344]]],
       device='cuda:0')
torch.Size([1, 11, 768])


### Replace words with tokenized versions

In [9]:
subject_0_tokenized = []
for sample in subjects_samples[0]:
    words = " ".join(sample[1])
    print(words)
    tokens = tokenizer.encode(words, add_special_tokens=False)
    print(tokens)
    subject_0_tokenized.append((sample[0], tokens))

Harry had never believed
[3466, 1125, 1309, 2475]
he would meet a
[1119, 1156, 2283, 170]
boy he hated more
[2298, 1119, 5687, 1167]
than Dudley, but that
[1190, 12840, 117, 1133, 1115]
was before he met
[1108, 1196, 1119, 1899]
Draco Malfoy. Still, first-year
[1987, 17312, 18880, 14467, 1183, 119, 4209, 117, 1148, 118, 1214]
Gryffindors only had Potions
[144, 1616, 16274, 8380, 1116, 1178, 1125, 18959, 6126]
with the Slytherins, so
[1114, 1103, 156, 1193, 8420, 4935, 117, 1177]
they didn't have to
[1152, 1238, 112, 189, 1138, 1106]
put up with Malfoy
[1508, 1146, 1114, 18880, 14467, 1183]
much. Or at least,
[1277, 119, 2926, 1120, 1655, 117]
they didn't until they
[1152, 1238, 112, 189, 1235, 1152]
spotted a notice pinned
[6910, 170, 4430, 11973]
up in the Gryffindor
[1146, 1107, 1103, 144, 1616, 16274, 8380]
common room that made
[1887, 1395, 1115, 1189]
them all groan. Flying
[1172, 1155, 13344, 119, 7769]
lessons would be starting
[8497, 1156, 1129, 2547]
on Thursday -- and
[1113, 

# Image encoding

In [23]:
#https://www.kaggle.com/code/schmoyote/simple-cnn-architecture-for-image-classification

In [10]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool3d(3)
        self.bn = nn.BatchNorm3d(1)
        self.conv1 = nn.Conv3d(1, 32, kernel_size = 3)
        self.conv2 = nn.Conv3d(32, 64, kernel_size = 3)
        self.conv3 = nn.Conv3d(64, 64, kernel_size = 3)
        
    def forward(self, x):
        out = self.bn(x)
        out = self.relu(self.conv1(x))
        out = self.max_pool(out)
        out = self.relu(self.conv2(out))
        out = self.max_pool(out)
        out = self.relu(self.conv3(out))
        out = out.view(out.size(0), -1)
        return out

In [11]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride = 1, downsample = None):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Sequential(
                        nn.Conv3d(in_channels, out_channels, kernel_size = 3, stride = stride, padding = 1),
                        nn.ReLU())
        self.conv2 = nn.Sequential(
                        nn.Conv3d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1))
        self.downsample = downsample
        self.relu = nn.ReLU()
        self.out_channels = out_channels
        
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        if self.downsample:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

In [12]:
#https://blog.paperspace.com/writing-resnet-from-scratch-in-pytorch/
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes = 10):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Sequential(
                        nn.Conv3d(1, 64, kernel_size = 7, stride = 2, padding = 3),
                        nn.ReLU())
        self.maxpool = nn.MaxPool3d(kernel_size = 3, stride = 2, padding = 1)
        self.layer0 = self._make_layer(block, 64, layers[0], stride = 1)
        self.layer1 = self._make_layer(block, 128, layers[1], stride = 2)
        self.layer2 = self._make_layer(block, 256, layers[2], stride = 2)
        self.layer3 = self._make_layer(block, 512, layers[3], stride = 2)
        self.avgpool = nn.AvgPool3d(2, stride=1) #i changed 7 to 2
        self.fc = nn.Linear(512, num_classes)
        
    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes:
            
            downsample = nn.Sequential(
                nn.Conv3d(self.inplanes, planes, kernel_size=1, stride=stride),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)
    
    
    def forward(self, x):
        #print(x.shape)
        x = self.conv1(x)
        #print(x.shape)
        x = self.maxpool(x)
        #print(x.shape)
        x = self.layer0(x)
        #print(x.shape)
        x = self.layer1(x)
        #print(x.shape)
        x = self.layer2(x)
        #print(x.shape)
        x = self.layer3(x)
        #print(x.shape)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
#         x = self.fc(x)

        return x

In [13]:
image_encoder = ResNet(ResidualBlock, [3, 4, 6, 3], num_classes=512).to(device)

# CLIP from scratch

In [14]:
def unison_shuffled_copies(a, b):
    assert len(a) == len(b)
    p = np.random.permutation(len(a))
    return a[p], b[p]

In [32]:
#clip model which simultaneously trains the text and image encoder as well as learning weights 
#for projecting encodings of both text and image to the same latent space

#model is based on psuedocode from original clip paper
class CLIP(nn.Module):
    def __init__(self, textEncoder, text_output_shape, imageEncoder, image_output_shape, embedding_dim):
        super(CLIP, self).__init__()
        #self.text_encoder = textEncoder(100, text_output_shape)
        self.text_output_shape = text_output_shape
        self.text_encoder = textEncoder
        self.text_proj = nn.Linear(text_output_shape, embedding_dim).to(device)
        #self.image_encoder = imageEncoder(1000, image_output_shape)
        self.image_output_shape = image_output_shape
        self.image_encoder = imageEncoder
        self.image_proj = nn.Linear(image_output_shape, embedding_dim).to(device)
    
    def forward(self, text, image):
        #gets encodings of text and images
#         text_features = self.text_encoder(text)
#         text_features = torch.zeros((len(text), self.text_output_shape))
        text_features = []
        for sample in text:
            sample_features = []
            for word in sample:
                stripped_word = re.sub("(?<!Mr)(?<!Mrs)\.($| )", "", word)
                sample_features += list(self.text_encoder.wv[stripped_word])
            text_features.append(sample_features)
        text_features = torch.tensor(text_features).float().to(device)
#         encodings = self.text_encoder(torch.tensor([sample]).to(device))['last_hidden_state']
#         text_features[idx] = torch.squeeze(torch.sum(encodings, dim=1), dim=0)
        image_features = self.image_encoder(torch.unsqueeze(torch.tensor(image).to(device), 1).float())
#         image_features = self.image_encoder(torch.reshape(torch.tensor(image).float().to(device), (image.shape[0], -1)))
#         image_features = torch.reshape(torch.tensor(image).float().to(device), (image.shape[0], -1))
    
        #projects text and images into latent space
        text_embed = self.text_proj(text_features)
        text_embed_norm = text_embed / text_embed.norm(dim=1, keepdim=True)
        image_embed = self.image_proj(image_features)
        image_embed_norm = image_embed/ image_embed.norm(dim=1, keepdim=True)
        
        #return text_embed_norm, image_embed_norm
        return text_embed, image_embed

In [43]:
#trains the clip model from scratch
def train_clip(model, text_samples, image_samples, batch_size=10, num_epochs=100, lr=1e-3, temp=0.07):
    #optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.2)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
#     scheduler = lr_scheduler.CosineAnnealingLR(optimizer, 1000)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        print("\tLearning_rate:",optimizer.param_groups[0]["lr"])
        epoch_loss = 0
        epoch_correct = 0
        epoch_total = 0
#         text_samples, image_samples = unison_shuffled_copies(text_samples, image_samples)
        zipped = list(zip(image_samples, text_samples))
        random.shuffle(zipped)
        image_samples, text_samples = zip(*zipped)
        image_samples = np.array(image_samples)
        #print(image_samples.shape)
        for batch in range(math.floor(image_samples.shape[0]/batch_size)):
#             print("\tBatch:", batch, "/", math.floor(image_samples.shape[0]/batch_size))
            optimizer.zero_grad()
            #gets embeddings for text and image batches
            start_idx = batch*batch_size
            end_idx = (batch+1)*batch_size
            text_batch, image_batch = text_samples[start_idx:end_idx], image_samples[start_idx:end_idx]
            text_embed, image_embed = model(text_batch, image_batch)
            #computes pairwise cosine similarity between text and image embeddings
            #logits = torch.matmul(text_embed, image_embed.t()) * math.exp(temp)
            logits = torch.cdist(text_embed, image_embed) #euclidean distances
            print(logits)
            #symmetric loss function
            labels = torch.arange(batch_size).to(device)
            loss_text = loss_fn(logits, labels)
            loss_image = loss_fn(logits.t(), labels)
            loss = (loss_text + loss_image)/2
            loss.backward()
            epoch_loss += loss.item()
            optimizer.step()
            #print(optimizer.state_dict)
            #compute accuracy
#             probs = torch.softmax(logits, dim=0)
#             print(probs)
#             print(torch.max(probs))
            winners = torch.argmax(logits, dim=0)
            print(winners)
            print(labels)
            corrects = (winners == labels)
            total_correct = corrects.sum().float().item()
            epoch_correct += total_correct
            epoch_total += batch_size
#         scheduler.step()
        print("Epoch:", epoch, "Training Loss:", epoch_loss, "Training Accuracy:", epoch_correct/epoch_total)
#         for param in model.parameters():
#             print(torch.absolute(param.grad.data).sum())
#         for name, param in model.named_parameters():
#             if param.requires_grad:
#                 print(param.grad)
#                 print(name, param.data)

### Tests out CLIP with very basic data and encoders

In [17]:
class TestModel(nn.Module):
    def __init__(self, input_size, output_size):
        super(TestModel, self).__init__()
        self.linear = nn.Linear(input_size, output_size)
    
    def forward(self, samples):
        return self.linear(samples)

In [17]:
#trains a clip model on random input data

text = torch.rand((50, 100))
images = torch.rand((50, 1000))

text_encoder = TestModel
image_encoder = TestModel

temp=0.07
clip_model = CLIP(TestModel(100, 60), 60, TestModel(1000,200), 200, 50)

torch.autograd.set_detect_anomaly(True)
train_clip(clip_model, text, images, num_epochs=5, temp=temp)

	Learning_rate: 0.001


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

In [51]:
#assesses the performance of the trained clip model on the entire training set
loss_fn = nn.CrossEntropyLoss()

text_embed, image_embed = clip_model(text, images)
#computes pairwise cosine similarity between text and image embeddings
logits = torch.matmul(text_embed, image_embed.t()) * math.exp(temp)
#symmetric loss function
labels = torch.arange(test_text.shape[0])
loss_text = loss_fn(logits, labels)
loss_image = loss_fn(logits.t(), labels)
loss = (loss_text + loss_image)/2
#compute accuracy
probs = torch.softmax(logits, dim=0)
winners = torch.argmax(probs, dim=0)
corrects = (winners == labels)
total_correct = corrects.sum().float().item()
print("Training Loss:", loss.item(), "Training Accuracy:", total_correct/test_text.shape[0])

Training Loss: 3.085871696472168 Training Accuracy: 1.0


In [52]:
#assesses the performance of the trained clip model on a test set

loss_fn = nn.CrossEntropyLoss()

test_text = torch.rand((50, 100))
test_images = torch.rand((50, 1000))

text_embed, image_embed = clip_model(test_text, test_images)
#computes pairwise cosine similarity between text and image embeddings
logits = torch.matmul(text_embed, image_embed.t()) * math.exp(temp)
#symmetric loss function
labels = torch.arange(test_text.shape[0])
loss_text = loss_fn(logits, labels)
loss_image = loss_fn(logits.t(), labels)
loss = (loss_text + loss_image)/2
#compute accuracy
probs = torch.softmax(logits, dim=0)
winners = torch.argmax(probs, dim=0)
corrects = (winners == labels)
total_correct = corrects.sum().float().item()
print("Testing Loss:", loss.item(), "Testing Accuracy:", total_correct/test_text.shape[0])

Testing Loss: 3.9154810905456543 Testing Accuracy: 0.08


### Train using real fMRI and text data

In [17]:
def split_samples(samples):
    images = np.zeros([len(samples)] + list(samples[0][0].shape))
    text = []
    for idx, sample in enumerate(samples):
        images[idx] = sample[0]
        text.append(sample[1])
    return images, text

In [19]:
#make sure samples are all tensors and properly formatted before inputting to torch model
train_images = torch.zeros((len(subjects_samples[0]), subjects_samples[0][0][0].flatten().shape[0]))
train_text = torch.zeros((len(subjects_samples[0]), 4*100))
for idx, sample in enumerate(subjects_samples[0]):
    train_images[idx] = torch.tensor(sample[0].flatten())
    for word_idx, word in enumerate(sample[1]):
        stripped_word = re.sub("(?<!Mr)(?<!Mrs)\.($| )", "", word)
        train_text[idx, word_idx*100:(word_idx+1)*100] = torch.tensor(word_vec_encoder.wv[stripped_word])
train_images = train_images.float().to(device)
print("before max:", torch.max(train_images))
print("before min:", torch.min(train_images))
train_images -= train_images.min(1, keepdim=True)[0]
train_images /= train_images.max(1, keepdim=True)[0]
print("after max:", torch.max(train_images))
print("after min:", torch.min(train_images))
train_text = train_text.float().to(device)
print(train_images.shape)
print(train_text.shape)

before max: tensor(750.8578, device='cuda:0')
before min: tensor(-0.0949, device='cuda:0')
after max: tensor(1., device='cuda:0')
after min: tensor(0., device='cuda:0')
torch.Size([1287, 159000])
torch.Size([1287, 400])


In [21]:
train_split = 0.05
temp = 0.07

train_samples = subjects_samples[0][:int(len(subjects_samples[0])*train_split)]
train_images, train_text = split_samples(train_samples)
print("Train:")
print(len(train_text))
print(train_images.shape)

test_samples = subjects_samples[0][int(len(subjects_samples[0])*train_split):]
test_images, test_text = split_samples(test_samples)
print("Test:")
print(len(test_text))
print(test_images.shape)

Train:
64
(64, 53, 60, 50)
Test:
1223
(1223, 53, 60, 50)


In [171]:
print(train_images[0].flatten().shape)

(159000,)


In [44]:
# text_encoder = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-cased').to(device)
# image_encoder = ResNet(ResidualBlock, [3, 4, 6, 3], num_classes=512).to(device)
text_encoder = word_vec_encoder
image_encoder = SimpleCNN().to(device)
#image_encoder = TestModel(train_images[0].flatten().shape[0],1024).to(device)
clip_model = CLIP(text_encoder, 4*100, image_encoder, 1152, 50)

In [45]:
train_clip(clip_model, train_text, train_images, batch_size=10, lr=1e-6, num_epochs=1000, temp=0.07)

	Learning_rate: 1e-06
tensor([[204.4028, 203.9049, 204.1459, 204.2730, 204.3450, 203.7423, 203.9562,
         203.8020, 203.5014, 204.2075],
        [204.5186, 204.0218, 204.2625, 204.3874, 204.4592, 203.8589, 204.0728,
         203.9180, 203.6172, 204.3225],
        [204.3156, 203.8184, 204.0591, 204.1852, 204.2578, 203.6552, 203.8698,
         203.7150, 203.4150, 204.1202],
        [204.3568, 203.8591, 204.1001, 204.2241, 204.2968, 203.6959, 203.9103,
         203.7543, 203.4533, 204.1594],
        [204.5398, 204.0446, 204.2851, 204.4107, 204.4835, 203.8808, 204.0958,
         203.9421, 203.6424, 204.3465],
        [204.4222, 203.9247, 204.1654, 204.2902, 204.3619, 203.7621, 203.9756,
         203.8205, 203.5195, 204.2252],
        [204.3257, 203.8297, 204.0709, 204.1964, 204.2710, 203.6649, 203.8818,
         203.7267, 203.4273, 204.1321],
        [204.4222, 203.9266, 204.1682, 204.2901, 204.3654, 203.7612, 203.9788,
         203.8220, 203.5215, 204.2267],
        [204.4247, 203.927

tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[203.7949, 204.1144, 204.2184, 203.9672, 204.0778, 203.2627, 203.9535,
         203.9403, 204.3383, 204.7593],
        [203.8488, 204.1674, 204.2700, 204.0198, 204.1302, 203.3156, 204.0060,
         203.9932, 204.3901, 204.8109],
        [203.8511, 204.1686, 204.2701, 204.0223, 204.1308, 203.3174, 204.0067,
         203.9951, 204.3905, 204.8117],
        [203.7532, 204.0730, 204.1763, 203.9254, 204.0359, 203.2208, 203.9114,
         203.8986, 204.2963, 204.7174],
        [203.9285, 204.2474, 204.3506, 204.1013, 204.2106, 203.3963, 204.0860,
         204.0737, 204.4705, 204.8927],
        [203.8103, 204.1271, 204.2287, 203.9816, 204.0892, 203.2767, 203.9654,
         203.9542, 204.3491, 204.7698],
        [203.7208, 204.0384, 204.1415, 203.8916, 204.0011, 203.1875, 203.8775,
         203.8651, 204.2616, 204.6811],
        [203.6376, 203.9571, 204.0609, 203.8110, 203.91

tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[204.5667, 204.1059, 204.1758, 203.8549, 204.0433, 204.1698, 203.7427,
         204.8625, 203.9787, 203.2270],
        [204.5339, 204.0708, 204.1422, 203.8220, 204.0063, 204.1331, 203.7064,
         204.8276, 203.9450, 203.1914],
        [204.4796, 204.0163, 204.0892, 203.7688, 203.9533, 204.0784, 203.6528,
         204.7733, 203.8903, 203.1366],
        [204.4719, 204.0094, 204.0797, 203.7596, 203.9444, 204.0721, 203.6445,
         204.7660, 203.8842, 203.1297],
        [204.4500, 203.9877, 204.0602, 203.7394, 203.9263, 204.0509, 203.6253,
         204.7450, 203.8608, 203.1085],
        [204.4988, 204.0368, 204.1095, 203.7883, 203.9758, 204.0999, 203.6747,
         204.7939, 203.9096, 203.1575],
        [204.2486, 203.7865, 203.8570, 203.5369, 203.7229, 203.8497, 203.4224,
         204.5432, 203.6608, 202.9065],
        [204.6391, 204.1776, 204.2495, 203.9284, 204.11

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[203.7008, 204.0482, 204.1113, 204.1299, 204.1205, 203.5883, 203.2225,
         204.0295, 203.6957, 203.8611],
        [203.6825, 204.0301, 204.0923, 204.1117, 204.1018, 203.5675, 203.2013,
         204.0137, 203.6775, 203.8415],
        [203.6474, 203.9962, 204.0592, 204.0777, 204.0692, 203.5365, 203.1700,
         203.9786, 203.6436, 203.8092],
        [203.6447, 203.9919, 204.0542, 204.0736, 204.0629, 203.5286, 203.1623,
         203.9756, 203.6395, 203.8029],
        [203.6273, 203.9756, 204.0377, 204.0576, 204.0488, 203.5143, 203.1480,
         203.9590, 203.6229, 203.7875],
        [203.7137, 204.0619, 204.1253, 204.1434, 204.1346, 203.6026, 203.2365,
         204.0436, 203.7094, 203.8753],
        [203.6848, 204.0320, 204.0944, 204.1140, 204.1045, 203.5711, 203.2054,
         204.0137, 203.6791, 203.8438],
        [203.7177, 204

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[205.0664, 204.5719, 203.5524, 205.2182, 204.0799, 204.9565, 203.7391,
         203.9603, 204.1038, 203.9410],
        [204.6155, 204.1214, 203.0998, 204.7672, 203.6265, 204.5051, 203.2870,
         203.5089, 203.6536, 203.4880],
        [204.7497, 204.2559, 203.2326, 204.9026, 203.7599, 204.6393, 203.4212,
         203.6438, 203.7880, 203.6232],
        [204.8228, 204.3308, 203.3088, 204.9754, 203.8348, 204.7137, 203.4975,
         203.7177, 203.8636, 203.6966],
        [204.8287, 204.3357, 203.3124, 204.9823, 203.8397, 204.7191, 203.5020,
         203.7240, 203.8685, 203.7027],
        [204.9838, 204.4923, 203.4689, 205.1370, 203.9954, 204.8746, 203.6586,
         203.8788, 204.0247, 203.8586],
        [204.9086, 204.4184, 203.3942, 205.0619, 203.9198, 204.7998, 203.5845,
         203.8039, 203.9511, 203.7831],
        [204.8529, 204

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[203.8203, 203.6234, 203.5329, 203.7997, 203.6935, 203.6188, 203.6213,
         203.4113, 204.2576, 203.1840],
        [203.7942, 203.5977, 203.5075, 203.7741, 203.6681, 203.5928, 203.5957,
         203.3855, 204.2319, 203.1580],
        [203.9212, 203.7242, 203.6352, 203.9012, 203.7946, 203.7189, 203.7239,
         203.5123, 204.3582, 203.2863],
        [203.7932, 203.5958, 203.5050, 203.7720, 203.6659, 203.5918, 203.5935,
         203.3838, 204.2305, 203.1562],
        [204.1514, 203.9562, 203.8679, 204.1329, 204.0256, 203.9498, 203.9555,
         203.7441, 204.5885, 203.5196],
        [204.1847, 203.9890, 203.8999, 204.1645, 204.0582, 203.9836, 203.9871,
         203.7767, 204.6221, 203.5515],
        [203.9800, 203.7852, 203.6954, 203.9614, 203.8542, 203.7796, 203.7823,
         203.5728, 204.4174, 203.3471],
        [203.9313, 203

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[204.0021, 204.3209, 203.8048, 203.7949, 204.1480, 203.9970, 204.6581,
         203.2379, 203.8243, 203.7108],
        [203.8987, 204.2191, 203.7024, 203.6923, 204.0450, 203.8938, 204.5562,
         203.1332, 203.7216, 203.6086],
        [203.7424, 204.0606, 203.5452, 203.5365, 203.8878, 203.7387, 204.3981,
         202.9767, 203.5640, 203.4515],
        [204.1336, 204.4524, 203.9364, 203.9271, 204.2787, 204.1290, 204.7888,
         203.3699, 203.9555, 203.8424],
        [203.8111, 204.1320, 203.6154, 203.6041, 203.9585, 203.8059, 204.4698,
         203.0453, 203.6344, 203.5218],
        [203.7686, 204.0863, 203.5709, 203.5625, 203.9137, 203.7647, 204.4237,
         203.0031, 203.5900, 203.4770],
        [203.8051, 204.1222, 203.6080, 203.5982, 203.9509, 203.8015, 204.4603,
         203.0405, 203.6264, 203.5144],
        [203.8526, 204

       grad_fn=<CdistBackward0>)
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[204.0490, 204.2369, 203.8018, 204.4740, 203.8430, 203.1669, 203.9467,
         203.9303, 203.6591, 203.5012],
        [203.7727, 203.9590, 203.5259, 204.1969, 203.5672, 202.8907, 203.6712,
         203.6535, 203.3828, 203.2239],
        [203.9618, 204.1487, 203.7150, 204.3867, 203.7551, 203.0805, 203.8601,
         203.8422, 203.5717, 203.4136],
        [203.9679, 204.1548, 203.7209, 204.3935, 203.7602, 203.0838, 203.8643,
         203.8470, 203.5763, 203.4191],
        [204.0691, 204.2561, 203.8228, 204.4946, 203.8615, 203.1888, 203.9677,
         203.9489, 203.6789, 203.5213],
        [204.0706, 204.2578, 203.8241, 204.4964, 203.8631, 203.1890, 203.9683,
         203.9504, 203.6800, 203.5225],
        [203.9951, 204.1814, 203.7504, 204.4209, 203.7919, 203.1156, 203.8950,
         203.8780, 203.6066, 203.4470],
        [203.8221, 204

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[203.4650, 204.4183, 203.8746, 203.3162, 204.2582, 203.4031, 203.0229,
         203.5797, 203.5886, 202.8424],
        [203.7102, 204.6637, 204.1202, 203.5616, 204.5023, 203.6486, 203.2684,
         203.8240, 203.8324, 203.0898],
        [203.2802, 204.2333, 203.6898, 203.1315, 204.0734, 203.2182, 202.8374,
         203.3937, 203.4035, 202.6557],
        [203.6940, 204.6475, 204.1043, 203.5453, 204.4858, 203.6329, 203.2526,
         203.8084, 203.8166, 203.0744],
        [203.5580, 204.5114, 203.9681, 203.4095, 204.3500, 203.4958, 203.1172,
         203.6730, 203.6816, 202.9369],
        [203.7178, 204.6714, 204.1281, 203.5690, 204.5096, 203.6569, 203.2760,
         203.8316, 203.8399, 203.0981],
        [203.4755, 204.4287, 203.8848, 203.3268, 204.2690, 203.4133, 203.0325,
         203.5888, 203.5981, 202.8518],
        [203.8801, 204

tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[203.4430, 203.5264, 203.7141, 203.3857, 204.5735, 203.6217, 203.2492,
         203.1771, 204.5006, 203.5840],
        [203.4954, 203.5777, 203.7666, 203.4393, 204.6268, 203.6734, 203.3011,
         203.2296, 204.5529, 203.6363],
        [203.3680, 203.4504, 203.6391, 203.3101, 204.4988, 203.5463, 203.1733,
         203.1020, 204.4256, 203.5092],
        [203.6604, 203.7437, 203.9319, 203.6016, 204.7904, 203.8385, 203.4674,
         203.3944, 204.7178, 203.8016],
        [203.5969, 203.6797, 203.8689, 203.5389, 204.7268, 203.7741, 203.4029,
         203.3308, 204.6549, 203.7371],
        [203.6524, 203.7364, 203.9229, 203.5973, 204.7840, 203.8318, 203.4599,
         203.3869, 204.7094, 203.7940],
        [203.6145, 203.6977, 203.8860, 203.5557, 204.7444, 203.7928, 203.4214,
         203.3485, 204.6719, 203.7558],
        [203.6123, 203.6949, 203.8843, 203.5542, 204.74

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[204.6940, 203.5068, 203.5101, 203.5269, 204.3791, 204.2263, 203.1508,
         203.2987, 203.7423, 203.8431],
        [204.7575, 203.5695, 203.5734, 203.5908, 204.4446, 204.2909, 203.2150,
         203.3623, 203.8067, 203.9073],
        [204.5715, 203.3833, 203.3862, 203.4031, 204.2570, 204.1019, 203.0278,
         203.1753, 203.6195, 203.7215],
        [204.7473, 203.5611, 203.5625, 203.5791, 204.4313, 204.2779, 203.2038,
         203.3515, 203.7957, 203.8953],
        [204.6118, 203.4233, 203.4261, 203.4430, 204.2977, 204.1417, 203.0679,
         203.2153, 203.6599, 203.7620],
        [204.7151, 203.5280, 203.5301, 203.5465, 204.3989, 204.2447, 203.1708,
         203.3190, 203.7626, 203.8639],
        [204.7952, 203.6073, 203.6110, 203.6283, 204.4825, 204.3283, 203.2526,
         203.3999, 203.8444, 203.9448],
        [204.7094, 203

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[202.8022, 203.5308, 203.1595, 204.2344, 203.5719, 204.3645, 203.4889,
         203.4654, 203.0837, 203.5094],
        [202.6219, 203.3493, 202.9789, 204.0525, 203.3917, 204.1837, 203.3064,
         203.2846, 202.9022, 203.3282],
        [202.8589, 203.5870, 203.2151, 204.2903, 203.6279, 204.4203, 203.5453,
         203.5214, 203.1388, 203.5662],
        [202.7780, 203.5049, 203.1333, 204.2067, 203.5469, 204.3382, 203.4614,
         203.4398, 203.0564, 203.4838],
        [202.6039, 203.3310, 202.9607, 204.0347, 203.3731, 204.1652, 203.2892,
         203.2658, 202.8831, 203.3110],
        [202.6972, 203.4240, 203.0534, 204.1259, 203.4660, 204.2569, 203.3815,
         203.3591, 202.9765, 203.4037],
        [202.8721, 203.5991, 203.2273, 204.3007, 203.6404, 204.4312, 203.5569,
         203.5336, 203.1506, 203.5790],
        [202.7604, 203

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[203.4960, 203.1449, 203.5166, 203.6799, 204.3735, 203.1152, 203.4104,
         203.5003, 203.5225, 203.5760],
        [203.4975, 203.1471, 203.5191, 203.6821, 204.3756, 203.1165, 203.4126,
         203.5012, 203.5253, 203.5771],
        [203.2595, 202.9111, 203.2814, 203.4480, 204.1359, 202.8816, 203.1778,
         203.2679, 203.2877, 203.3428],
        [203.2722, 202.9209, 203.2916, 203.4556, 204.1496, 202.8901, 203.1849,
         203.2750, 203.2982, 203.3517],
        [203.3137, 202.9625, 203.3335, 203.4975, 204.1909, 202.9322, 203.2272,
         203.3173, 203.3398, 203.3937],
        [203.2526, 202.9026, 203.2734, 203.4383, 204.1298, 202.8718, 203.1678,
         203.2571, 203.2802, 203.3340],
        [203.3994, 203.0492, 203.4202, 203.5847, 204.2764, 203.0197, 203.3149,
         203.4053, 203.4260, 203.4802],
        [203.4599, 203

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[202.9450, 203.3731, 203.5759, 203.4551, 203.2416, 203.7045, 203.5472,
         202.9917, 203.3927, 204.0080],
        [203.0566, 203.4847, 203.6848, 203.5651, 203.3523, 203.8159, 203.6586,
         203.1028, 203.5034, 204.1169],
        [202.8366, 203.2645, 203.4670, 203.3464, 203.1328, 203.5964, 203.4394,
         202.8836, 203.2839, 203.8990],
        [202.8531, 203.2809, 203.4820, 203.3626, 203.1489, 203.6127, 203.4561,
         202.9004, 203.2999, 203.9149],
        [202.9423, 203.3712, 203.5754, 203.4517, 203.2391, 203.7029, 203.5440,
         202.9876, 203.3906, 204.0039],
        [202.6945, 203.1238, 203.3277, 203.2056, 202.9912, 203.4563, 203.2976,
         202.7416, 203.1430, 203.7585],
        [203.1871, 203.6153, 203.8177, 203.6958, 203.4837, 203.9461, 203.7880,
         203.2322, 203.6347, 204.2479],
        [202.7801, 203

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[204.1152, 203.2372, 203.4124, 203.6298, 203.2167, 202.8881, 203.2497,
         203.1930, 204.0896, 203.7580],
        [204.4652, 203.5868, 203.7634, 203.9794, 203.5674, 203.2379, 203.6001,
         203.5444, 204.4403, 204.1095],
        [204.4593, 203.5806, 203.7579, 203.9738, 203.5619, 203.2323, 203.5946,
         203.5387, 204.4348, 204.1039],
        [204.3332, 203.4548, 203.6351, 203.8503, 203.4366, 203.1084, 203.4694,
         203.4146, 204.3109, 203.9792],
        [204.3122, 203.4341, 203.6098, 203.8261, 203.4139, 203.0849, 203.4465,
         203.3907, 204.2865, 203.9553],
        [204.2999, 203.4207, 203.5975, 203.8149, 203.4023, 203.0725, 203.4355,
         203.3783, 204.2755, 203.9445],
        [204.1188, 203.2408, 203.4214, 203.6384, 203.2224, 202.8951, 203.2559,
         203.2004, 204.0983, 203.7659],
        [204.1998, 203

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[203.2756, 202.9374, 202.8323, 203.6932, 203.3190, 203.2335, 203.6086,
         203.5386, 203.5487, 203.6636],
        [203.2206, 202.8823, 202.7772, 203.6383, 203.2638, 203.1782, 203.5535,
         203.4837, 203.4940, 203.6085],
        [202.9543, 202.6139, 202.5119, 203.3716, 202.9966, 202.9101, 203.2895,
         203.2161, 203.2276, 203.3434],
        [203.1512, 202.8120, 202.7084, 203.5689, 203.1944, 203.1084, 203.4852,
         203.4134, 203.4241, 203.5398],
        [203.2042, 202.8663, 202.7619, 203.6210, 203.2473, 203.1615, 203.5388,
         203.4677, 203.4782, 203.5932],
        [203.2179, 202.8797, 202.7765, 203.6344, 203.2613, 203.1757, 203.5533,
         203.4805, 203.4910, 203.6077],
        [203.2195, 202.8821, 202.7758, 203.6372, 203.2630, 203.1774, 203.5519,
         203.4839, 203.4939, 203.6073],
        [203.4615, 203

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
Epoch: 20 Training Loss: 14.085168838500977 Training Accuracy: 0.1
	Learning_rate: 1e-06
tensor([[203.6121, 202.7527, 203.2902, 203.7022, 203.6982, 203.5695, 203.3336,
         203.1062, 203.7593, 203.5082],
        [203.5179, 202.6604, 203.1967, 203.6082, 203.6039, 203.4763, 203.2397,
         203.0131, 203.6662, 203.4157],
        [203.6446, 202.7883, 203.3237, 203.7345, 203.7304, 203.6034, 203.3672,
         203.1403, 203.7943, 203.5443],
        [203.6683, 202.8104, 203.3464, 203.7587, 203.7539, 203.6272, 203.3907,
         203.1638, 203.8152, 203.5652],
        [203.8072, 202.9506, 203.4860, 203.8969, 203.8931, 203.7660, 203.5303,
         203.3027, 203.9548, 203.7050],
        [203.5170, 202.6564, 203.1943, 203.6077, 203.6028, 203.4745, 203.2380,
         203.0111, 203.6637, 203.4125],
        [203.3965, 202.5365, 203.0734, 203.4877, 203

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
Epoch: 21 Training Loss: 14.067477703094482 Training Accuracy: 0.1
	Learning_rate: 1e-06
tensor([[202.6149, 203.1715, 203.5659, 203.3611, 204.0876, 203.2216, 204.2254,
         202.8662, 203.4569, 203.2972],
        [202.6017, 203.1581, 203.5531, 203.3474, 204.0760, 203.2085, 204.2123,
         202.8531, 203.4439, 203.2836],
        [202.7849, 203.3423, 203.7372, 203.5315, 204.2599, 203.3917, 204.3958,
         203.0368, 203.6282, 203.4665],
        [202.4972, 203.0529, 203.4479, 203.2408, 203.9710, 203.1039, 204.1075,
         202.7481, 203.3390, 203.1804],
        [202.5147, 203.0706, 203.4664, 203.2580, 203.9914, 203.1214, 204.1254,
         202.7662, 203.3572, 203.1969],
        [202.4749, 203.0306, 203.4261, 203.2183, 203.9508, 203.0816, 204.0854,
         202.7261, 203.3169, 203.1573],
        [202.6898, 203.2459, 203.6412, 203.4335, 204

tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[203.6839, 203.4167, 203.5315, 203.3836, 203.4385, 203.2678, 203.3626,
         203.7041, 203.6846, 203.5007],
        [203.3177, 203.0492, 203.1627, 203.0162, 203.0706, 202.8998, 202.9931,
         203.3368, 203.3164, 203.1320],
        [203.7317, 203.4638, 203.5770, 203.4312, 203.4858, 203.3145, 203.4075,
         203.7509, 203.7311, 203.5467],
        [203.4316, 203.1636, 203.2766, 203.1305, 203.1858, 203.0141, 203.1075,
         203.4508, 203.4311, 203.2457],
        [203.3979, 203.1310, 203.2454, 203.0967, 203.1528, 202.9822, 203.0774,
         203.4181, 203.3997, 203.2139],
        [203.5455, 203.2779, 203.3912, 203.2448, 203.3006, 203.1286, 203.2223,
         203.5650, 203.5458, 203.3601],
        [203.5355, 203.2664, 203.3790, 203.2341, 203.2877, 203.1168, 203.2087,
         203.5537, 203.5329, 203.3490],
        [203.4180, 203.1496, 203.2631, 203.1164, 203.17

tensor([[202.5164, 203.2917, 203.0941, 202.9875, 202.6931, 203.4503, 202.9245,
         202.5866, 203.1120, 203.0282],
        [202.7610, 203.5344, 203.3402, 203.2328, 202.9372, 203.6941, 203.1715,
         202.8306, 203.3585, 203.2720],
        [202.6136, 203.3865, 203.1924, 203.0846, 202.7884, 203.5466, 203.0247,
         202.6825, 203.2120, 203.1247],
        [202.5172, 203.2906, 203.0972, 202.9885, 202.6926, 203.4512, 202.9270,
         202.5881, 203.1158, 203.0285],
        [202.5197, 203.2939, 203.0987, 202.9910, 202.6953, 203.4528, 202.9310,
         202.5889, 203.1183, 203.0312],
        [202.2707, 203.0445, 202.8498, 202.7408, 202.4451, 203.2047, 202.6807,
         202.3410, 202.8698, 202.7825],
        [202.5164, 203.2896, 203.0956, 202.9873, 202.6911, 203.4494, 202.9282,
         202.5854, 203.1157, 203.0277],
        [202.2535, 203.0265, 202.8321, 202.7231, 202.4271, 203.1876, 202.6629,
         202.3237, 202.8522, 202.7652],
        [202.5988, 203.3719, 203.1780, 203.0700,

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[203.4737, 204.1501, 202.8772, 203.2979, 203.3252, 203.3261, 203.4973,
         203.1111, 203.0950, 204.4817],
        [203.1529, 203.8279, 202.5545, 202.9763, 203.0022, 203.0040, 203.1769,
         202.7893, 202.7737, 204.1612],
        [203.2332, 203.9086, 202.6353, 203.0578, 203.0833, 203.0848, 203.2579,
         202.8707, 202.8543, 204.2419],
        [203.3113, 203.9873, 202.7139, 203.1371, 203.1625, 203.1638, 203.3371,
         202.9499, 202.9332, 204.3207],
        [203.0789, 203.7548, 202.4803, 202.9009, 202.9275, 202.9297, 203.1008,
         202.7133, 202.6991, 204.0860],
        [203.2659, 203.9418, 202.6681, 203.0886, 203.1158, 203.1174, 203.2888,
         202.9017, 202.8867, 204.2736],
        [203.1143, 203.7913, 202.5161, 202.9407, 202.9652, 202.9673, 203.1402,
         202.7523, 202.7364, 204.1237],
        [203.2289, 203

       grad_fn=<CdistBackward0>)
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[203.0936, 203.6119, 203.5627, 203.2969, 203.1819, 202.8897, 203.3520,
         203.1002, 203.4449, 202.7462],
        [203.0263, 203.5481, 203.4962, 203.2312, 203.1172, 202.8239, 203.2865,
         203.0348, 203.3788, 202.6804],
        [202.6558, 203.1752, 203.1253, 202.8594, 202.7455, 202.4525, 202.9147,
         202.6631, 203.0072, 202.3096],
        [202.9053, 203.4243, 203.3754, 203.1087, 202.9945, 202.7028, 203.1651,
         202.9131, 203.2575, 202.5596],
        [202.9553, 203.4736, 203.4247, 203.1589, 203.0436, 202.7520, 203.2143,
         202.9622, 203.3072, 202.6091],
        [203.0081, 203.5302, 203.4783, 203.2128, 203.0996, 202.8062, 203.2687,
         203.0171, 203.3607, 202.6623],
        [202.9775, 203.4993, 203.4473, 203.1825, 203.0683, 202.7752, 203.2378,
         202.9859, 203.3301, 202.6318],
        [202.8709, 203

tensor([[203.0772, 203.4225, 203.3732, 203.0686, 203.1011, 203.0296, 203.3306,
         203.2727, 202.4756, 202.6913],
        [203.1326, 203.4796, 203.4300, 203.1246, 203.1583, 203.0862, 203.3867,
         203.3296, 202.5309, 202.7476],
        [203.0261, 203.3722, 203.3236, 203.0193, 203.0515, 202.9783, 203.2792,
         203.2226, 202.4243, 202.6408],
        [203.2135, 203.5592, 203.5096, 203.2067, 203.2382, 203.1658, 203.4663,
         203.4101, 202.6116, 202.8286],
        [202.9061, 203.2535, 203.2039, 202.9011, 202.9320, 202.8592, 203.1602,
         203.1045, 202.3050, 202.5219],
        [202.9496, 203.2957, 203.2474, 202.9415, 202.9750, 202.9021, 203.2032,
         203.1456, 202.3480, 202.5637],
        [202.9308, 203.2768, 203.2287, 202.9233, 202.9562, 202.8831, 203.1842,
         203.1270, 202.3292, 202.5450],
        [203.0815, 203.4283, 203.3789, 203.0746, 203.1071, 203.0345, 203.3352,
         203.2787, 202.4799, 202.6967],
        [202.9628, 203.3086, 203.2603, 202.9536,

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[202.9768, 202.9473, 203.4175, 202.8624, 202.5722, 203.5804, 202.9292,
         204.0534, 203.4673, 202.7455],
        [203.1912, 203.1627, 203.6315, 203.0761, 202.7856, 203.7950, 203.1447,
         204.2668, 203.6821, 202.9595],
        [202.5864, 202.5576, 203.0272, 202.4717, 202.1804, 203.1904, 202.5385,
         203.6635, 203.0767, 202.3549],
        [202.9994, 202.9715, 203.4407, 202.8844, 202.5946, 203.6044, 202.9532,
         204.0751, 203.4892, 202.7682],
        [202.7816, 202.7525, 203.2218, 202.6666, 202.3751, 203.3849, 202.7341,
         203.8583, 203.2727, 202.5498],
        [202.6475, 202.6178, 203.0885, 202.5330, 202.2428, 203.2509, 202.5997,
         203.7246, 203.1375, 202.4163],
        [203.0379, 203.0103, 203.4791, 202.9230, 202.6327, 203.6431, 202.9914,
         204.1138, 203.5281, 202.8066],
        [203.0975, 203

tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[202.4482, 202.1946, 202.6887, 203.0575, 202.5681, 202.8903, 202.4587,
         202.1633, 202.9920, 202.7296],
        [202.8478, 202.5929, 203.0879, 203.4587, 202.9678, 203.2891, 202.8569,
         202.5613, 203.3918, 203.1261],
        [202.6832, 202.4284, 202.9223, 203.2934, 202.8029, 203.1243, 202.6915,
         202.3963, 203.2272, 202.9616],
        [202.7799, 202.5251, 203.0205, 203.3911, 202.9002, 203.2214, 202.7897,
         202.4937, 203.3235, 203.0591],
        [202.7994, 202.5444, 203.0400, 203.4088, 202.9189, 203.2405, 202.8100,
         202.5142, 203.3424, 203.0789],
        [202.6924, 202.4387, 202.9344, 203.3042, 202.8133, 203.1350, 202.7037,
         202.4074, 203.2362, 202.9734],
        [202.7505, 202.4952, 202.9894, 203.3581, 202.8690, 203.1908, 202.7594,
         202.4647, 203.2940, 203.0289],
        [202.6038, 202.3491, 202.8432, 203.2143, 202.72

tensor([4, 4, 2, 2, 4, 4, 4, 2, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[202.9196, 202.9400, 202.9059, 202.9759, 203.5892, 202.3381, 202.3717,
         203.0193, 202.8030, 203.3482],
        [202.9992, 203.0193, 202.9836, 203.0547, 203.6684, 202.4171, 202.4497,
         203.0993, 202.8815, 203.4265],
        [202.8844, 202.9028, 202.8672, 202.9383, 203.5550, 202.3018, 202.3347,
         202.9828, 202.7664, 203.3125],
        [202.9935, 203.0135, 202.9778, 203.0485, 203.6624, 202.4107, 202.4438,
         203.0932, 202.8748, 203.4194],
        [202.9257, 202.9444, 202.9091, 202.9803, 203.5962, 202.3437, 202.3762,
         203.0243, 202.8084, 203.3544],
        [202.9354, 202.9565, 202.9223, 202.9919, 203.6041, 202.3534, 202.3876,
         203.0356, 202.8181, 203.3625],
        [203.1109, 203.1295, 203.0931, 203.1647, 203.7804, 202.5280, 202.5603,
         203.2095, 202.9916, 203.5364],
        [203.0892, 203.1083, 203.0720, 203.1434, 203.75

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[202.2460, 203.1422, 202.4345, 202.2446, 202.6305, 202.7091, 202.8213,
         203.0152, 202.0690, 202.5326],
        [202.4225, 203.3190, 202.6110, 202.4214, 202.8069, 202.8854, 202.9971,
         203.1909, 202.2462, 202.7086],
        [202.4068, 203.3001, 202.5970, 202.4056, 202.7900, 202.8700, 202.9835,
         203.1723, 202.2325, 202.6938],
        [202.3185, 203.2118, 202.5053, 202.3150, 202.7010, 202.7794, 202.8914,
         203.0851, 202.1408, 202.6027],
        [202.2663, 203.1611, 202.4556, 202.2650, 202.6503, 202.7296, 202.8426,
         203.0338, 202.0904, 202.5533],
        [202.3195, 203.2112, 202.5064, 202.3157, 202.7016, 202.7806, 202.8930,
         203.0848, 202.1417, 202.6039],
        [202.2918, 203.1866, 202.4816, 202.2907, 202.6757, 202.7552, 202.8684,
         203.0592, 202.1165, 202.5789],
        [202.5275, 203

tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[203.0717, 202.6660, 203.3878, 203.3154, 203.1969, 203.4989, 203.0576,
         202.9594, 202.4511, 203.0146],
        [202.9510, 202.5462, 203.2690, 203.1956, 203.0779, 203.3795, 202.9376,
         202.8406, 202.3296, 202.8939],
        [202.9689, 202.5639, 203.2868, 203.2133, 203.0953, 203.3969, 202.9554,
         202.8581, 202.3479, 202.9121],
        [202.9353, 202.5302, 203.2512, 203.1802, 203.0627, 203.3648, 202.9215,
         202.8236, 202.3132, 202.8773],
        [202.8884, 202.4819, 203.2027, 203.1336, 203.0133, 203.3171, 202.8726,
         202.7753, 202.2647, 202.8286],
        [202.6588, 202.2542, 202.9749, 202.9042, 202.7862, 203.0872, 202.6473,
         202.5473, 202.0375, 202.6025],
        [202.8120, 202.4056, 203.1267, 203.0574, 202.9369, 203.2407, 202.7965,
         202.6991, 202.1880, 202.7523],
        [202.8977, 202.4920, 203.2123, 203.1421, 203.02

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[203.0499, 201.8271, 202.6104, 203.1408, 202.8067, 202.7238, 202.4198,
         201.7236, 202.2934, 203.5220],
        [203.3411, 202.1206, 202.9035, 203.4347, 203.1002, 203.0171, 202.7114,
         202.0155, 202.5879, 203.8152],
        [203.2034, 201.9787, 202.7595, 203.2922, 202.9588, 202.8751, 202.5739,
         201.8765, 202.4453, 203.6724],
        [203.3484, 202.1249, 202.9054, 203.4367, 203.1026, 203.0198, 202.7180,
         202.0210, 202.5908, 203.8187],
        [203.6022, 202.3800, 203.1603, 203.6929, 203.3589, 203.2753, 202.9727,
         202.2756, 202.8479, 204.0732],
        [203.4059, 202.1837, 202.9642, 203.4981, 203.1643, 203.0804, 202.7771,
         202.0798, 202.6510, 203.8775],
        [203.5937, 202.3709, 203.1507, 203.6838, 203.3501, 203.2663, 202.9644,
         202.2669, 202.8387, 204.0639],
        [203.3698, 202

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[202.7692, 202.9564, 202.9073, 203.5742, 201.9768, 202.0444, 202.3828,
         202.4829, 202.8944, 203.4203],
        [203.1408, 203.3289, 203.2805, 203.9472, 202.3504, 202.4176, 202.7560,
         202.8562, 203.2671, 203.7934],
        [202.6376, 202.8259, 202.7768, 203.4433, 201.8458, 201.9139, 202.2511,
         202.3523, 202.7632, 203.2902],
        [202.9578, 203.1446, 203.0962, 203.7630, 202.1660, 202.2330, 202.5721,
         202.6720, 203.0829, 203.6087],
        [202.8796, 203.0703, 203.0220, 203.6881, 202.0911, 202.1592, 202.4944,
         202.5976, 203.0067, 203.5360],
        [202.6201, 202.8086, 202.7598, 203.4259, 201.8286, 201.8967, 202.2333,
         202.3350, 202.7455, 203.2731],
        [202.8844, 203.0698, 203.0212, 203.6885, 202.0912, 202.1581, 202.4983,
         202.5972, 203.0087, 203.5336],
        [202.9421, 203

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[202.9005, 202.8914, 202.5481, 202.8730, 202.0257, 202.7157, 202.5083,
         203.2709, 202.5794, 202.3606],
        [202.9048, 202.8959, 202.5529, 202.8777, 202.0299, 202.7207, 202.5126,
         203.2751, 202.5837, 202.3655],
        [203.0521, 203.0436, 202.7029, 203.0267, 202.1767, 202.8721, 202.6607,
         203.4214, 202.7313, 202.5163],
        [202.8841, 202.8758, 202.5340, 202.8586, 202.0093, 202.7019, 202.4919,
         203.2546, 202.5632, 202.3460],
        [202.7079, 202.6996, 202.3579, 202.6813, 201.8331, 202.5262, 202.3153,
         203.0781, 202.3870, 202.1694],
        [202.7642, 202.7551, 202.4127, 202.7363, 201.8884, 202.5816, 202.3719,
         203.1331, 202.4428, 202.2257],
        [202.9573, 202.9483, 202.6051, 202.9302, 202.0824, 202.7728, 202.5652,
         203.3277, 202.6362, 202.4179],
        [202.9733, 202

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[202.6728, 202.7042, 203.2074, 202.9348, 201.9958, 203.8660, 203.1430,
         203.4096, 202.8338, 203.7888],
        [202.4815, 202.5139, 203.0183, 202.7447, 201.8040, 203.6747, 202.9515,
         203.2210, 202.6459, 203.6000],
        [202.5922, 202.6244, 203.1289, 202.8540, 201.9146, 203.7843, 203.0619,
         203.3310, 202.7546, 203.7096],
        [202.6143, 202.6458, 203.1485, 202.8775, 201.9368, 203.8079, 203.0845,
         203.3514, 202.7773, 203.7307],
        [202.6045, 202.6369, 203.1401, 202.8664, 201.9254, 203.7956, 203.0736,
         203.3428, 202.7672, 203.7205],
        [202.5582, 202.5904, 203.0957, 202.8205, 201.8813, 203.7513, 203.0283,
         203.2975, 202.7212, 203.6767],
        [202.6109, 202.6425, 203.1454, 202.8731, 201.9342, 203.8043, 203.0814,
         203.3477, 202.7719, 203.7269],
        [202.5779, 202

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[202.6409, 202.9608, 202.7458, 202.9963, 202.8322, 202.5699, 202.7440,
         203.0436, 203.1047, 202.0668],
        [202.6074, 202.9281, 202.7108, 202.9610, 202.7982, 202.5347, 202.7092,
         203.0113, 203.0696, 202.0328],
        [202.4538, 202.7753, 202.5598, 202.8091, 202.6488, 202.3836, 202.5603,
         202.8571, 202.9182, 201.8813],
        [202.5838, 202.9052, 202.6877, 202.9364, 202.7779, 202.5114, 202.6877,
         202.9878, 203.0457, 202.0100],
        [202.7884, 203.1083, 202.8921, 203.1416, 202.9797, 202.7160, 202.8902,
         203.1915, 203.2502, 202.2139],
        [202.2279, 202.5501, 202.3328, 202.5833, 202.4222, 202.1568, 202.3338,
         202.6324, 202.6925, 201.6545],
        [202.7612, 203.0811, 202.8651, 203.1140, 202.9537, 202.6889, 202.8635,
         203.1641, 203.2228, 202.1868],
        [202.8029, 203

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[202.4077, 202.6771, 202.5631, 202.8696, 203.1606, 203.0516, 202.7772,
         202.9853, 202.7569, 202.7065],
        [202.2883, 202.5581, 202.4440, 202.7517, 203.0443, 202.9308, 202.6587,
         202.8662, 202.6372, 202.5887],
        [202.4426, 202.7124, 202.5975, 202.9035, 203.1961, 203.0850, 202.8121,
         203.0193, 202.7910, 202.7426],
        [202.4477, 202.7183, 202.6023, 202.9066, 203.1989, 203.0901, 202.8171,
         203.0237, 202.7965, 202.7482],
        [202.1124, 202.3835, 202.2686, 202.5754, 202.8670, 202.7556, 202.4836,
         202.6907, 202.4628, 202.4140],
        [202.4232, 202.6930, 202.5788, 202.8848, 203.1765, 203.0672, 202.7925,
         203.0003, 202.7724, 202.7224],
        [202.4460, 202.7157, 202.6003, 202.9067, 203.1992, 203.0868, 202.8160,
         203.0229, 202.7939, 202.7464],
        [202.4418, 202

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[202.7388, 201.6929, 202.9063, 202.0440, 202.4788, 202.7399, 202.1844,
         202.5770, 202.5329, 202.0115],
        [202.8371, 201.7911, 203.0042, 202.1424, 202.5764, 202.8386, 202.2824,
         202.6748, 202.6313, 202.1102],
        [202.8252, 201.7803, 202.9931, 202.1304, 202.5657, 202.8267, 202.2711,
         202.6627, 202.6188, 202.0990],
        [202.8927, 201.8479, 203.0605, 202.1980, 202.6331, 202.8947, 202.3384,
         202.7295, 202.6861, 202.1669],
        [202.7533, 201.7086, 202.9187, 202.0591, 202.4921, 202.7551, 202.1991,
         202.5912, 202.5467, 202.0269],
        [202.9787, 201.9326, 203.1449, 202.2839, 202.7174, 202.9804, 202.4242,
         202.8167, 202.7733, 202.2520],
        [202.8073, 201.7627, 202.9755, 202.1124, 202.5483, 202.8087, 202.2535,
         202.6449, 202.6006, 202.0813],
        [202.6668, 201

       grad_fn=<CdistBackward0>)
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[202.4988, 202.6658, 202.6433, 202.0526, 202.6355, 202.8098, 202.5590,
         202.7915, 202.4504, 202.4688],
        [202.3489, 202.5174, 202.4920, 201.9009, 202.4878, 202.6607, 202.4093,
         202.6430, 202.3017, 202.3207],
        [202.7070, 202.8738, 202.8506, 202.2612, 202.8440, 203.0181, 202.7673,
         203.0011, 202.6600, 202.6775],
        [202.4662, 202.6348, 202.6106, 202.0198, 202.6049, 202.7778, 202.5262,
         202.7602, 202.4190, 202.4375],
        [202.2568, 202.4242, 202.4018, 201.8100, 202.3940, 202.5682, 202.3168,
         202.5492, 202.2081, 202.2267],
        [202.3976, 202.5631, 202.5416, 201.9508, 202.5330, 202.7088, 202.4580,
         202.6905, 202.3495, 202.3668],
        [202.4073, 202.5755, 202.5522, 201.9611, 202.5457, 202.7190, 202.4670,
         202.7015, 202.3603, 202.3780],
        [202.3770, 202

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[202.0025, 202.1779, 202.0905, 202.6082, 202.9417, 201.5687, 202.4229,
         201.6751, 202.1819, 203.3422],
        [202.2293, 202.4048, 202.3161, 202.8338, 203.1668, 201.7941, 202.6477,
         201.9021, 202.4077, 203.5677],
        [202.4008, 202.5760, 202.4889, 203.0051, 203.3367, 201.9650, 202.8188,
         202.0720, 202.5788, 203.7394],
        [202.3338, 202.5094, 202.4217, 202.9414, 203.2731, 201.9007, 202.7542,
         202.0066, 202.5141, 203.6745],
        [202.2410, 202.4163, 202.3284, 202.8449, 203.1776, 201.8053, 202.6591,
         201.9130, 202.4189, 203.5792],
        [202.0557, 202.2308, 202.1442, 202.6617, 202.9941, 201.6216, 202.4759,
         201.7261, 202.2352, 203.3956],
        [202.5138, 202.6896, 202.6008, 203.1203, 203.4521, 202.0799, 202.9329,
         202.1890, 202.6930, 203.8534],
        [202.3217, 202

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[202.4156, 202.2364, 202.5665, 202.2434, 201.8105, 202.3704, 202.7449,
         202.1865, 202.1025, 202.5592],
        [202.6727, 202.4933, 202.8224, 202.4958, 202.0688, 202.6280, 203.0005,
         202.4429, 202.3590, 202.8132],
        [202.3135, 202.1332, 202.4637, 202.1396, 201.7075, 202.2680, 202.6420,
         202.0833, 201.9987, 202.4564],
        [202.4663, 202.2857, 202.6156, 202.2904, 201.8613, 202.4214, 202.7942,
         202.2355, 202.1513, 202.6075],
        [202.5614, 202.3817, 202.7112, 202.3872, 201.9571, 202.5165, 202.8892,
         202.3321, 202.2483, 202.7033],
        [202.4410, 202.2613, 202.5912, 202.2647, 201.8354, 202.3953, 202.7684,
         202.2110, 202.1265, 202.5823],
        [202.5744, 202.3961, 202.7257, 202.4008, 201.9702, 202.5301, 202.9048,
         202.3453, 202.2613, 202.7175],
        [202.5455, 202

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[203.1027, 202.2139, 201.5393, 202.2372, 202.8592, 202.4875, 202.6978,
         202.1436, 202.5695, 202.3792],
        [203.2991, 202.4106, 201.7356, 202.4339, 203.0555, 202.6844, 202.8943,
         202.3401, 202.7660, 202.5765],
        [203.2290, 202.3421, 201.6651, 202.3647, 202.9832, 202.6156, 202.8244,
         202.2692, 202.6972, 202.5095],
        [203.0622, 202.1736, 201.4992, 202.1975, 202.8189, 202.4478, 202.6578,
         202.1034, 202.5301, 202.3396],
        [202.9066, 202.0194, 201.3449, 202.0440, 202.6639, 202.2947, 202.5042,
         201.9489, 202.3766, 202.1875],
        [203.1687, 202.2812, 201.6053, 202.3048, 202.9238, 202.5555, 202.7643,
         202.2093, 202.6378, 202.4488],
        [202.8467, 201.9585, 201.2844, 201.9830, 202.6036, 202.2333, 202.4431,
         201.8883, 202.3159, 202.1251],
        [202.8848, 201

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[201.7155, 201.9220, 201.6752, 202.3514, 202.8902, 202.2639, 202.4398,
         201.8901, 203.5135, 202.2974],
        [201.8296, 202.0359, 201.7904, 202.4649, 203.0047, 202.3786, 202.5526,
         202.0053, 203.6292, 202.4127],
        [201.7433, 201.9519, 201.7062, 202.3824, 202.9187, 202.2917, 202.4706,
         201.9196, 203.5439, 202.3276],
        [201.9635, 202.1686, 201.9230, 202.5982, 203.1374, 202.5104, 202.6867,
         202.1382, 203.7609, 202.5449],
        [201.4504, 201.6597, 201.4130, 202.0905, 202.6280, 202.0008, 202.1772,
         201.6266, 203.2510, 202.0345],
        [201.8292, 202.0341, 201.7885, 202.4641, 203.0050, 202.3775, 202.5508,
         202.0038, 203.6265, 202.4101],
        [201.9904, 202.1953, 201.9502, 202.6249, 203.1647, 202.5378, 202.7127,
         202.1654, 203.7884, 202.5720],
        [201.6554, 201

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[201.6795, 202.4821, 202.5593, 202.6944, 202.3954, 201.5362, 202.0250,
         202.5795, 201.8860, 202.6714],
        [201.7130, 202.5138, 202.5912, 202.7276, 202.4289, 201.5692, 202.0580,
         202.6121, 201.9188, 202.7044],
        [201.9068, 202.7091, 202.7857, 202.9212, 202.6211, 201.7627, 202.2523,
         202.8064, 202.1125, 202.8967],
        [201.3527, 202.1608, 202.2368, 202.3712, 202.0721, 201.2134, 201.7016,
         202.2558, 201.5625, 202.3494],
        [201.5021, 202.3073, 202.3837, 202.5192, 202.2196, 201.3608, 201.8498,
         202.4038, 201.7105, 202.4962],
        [201.7368, 202.5416, 202.6183, 202.7523, 202.4522, 201.5940, 202.0833,
         202.6380, 201.9442, 202.7286],
        [201.7065, 202.5138, 202.5893, 202.7236, 202.4238, 201.5657, 202.0545,
         202.6084, 201.9144, 202.7005],
        [201.6370, 202

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[202.2842, 202.7208, 202.5817, 202.4823, 202.6037, 201.7280, 202.4800,
         202.7299, 202.6308, 202.0743],
        [202.2228, 202.6575, 202.5167, 202.4158, 202.5424, 201.6642, 202.4163,
         202.6656, 202.5663, 202.0112],
        [202.3006, 202.7368, 202.5978, 202.4980, 202.6201, 201.7443, 202.4961,
         202.7459, 202.6468, 202.0906],
        [202.1362, 202.5699, 202.4292, 202.3287, 202.4542, 201.5772, 202.3288,
         202.5785, 202.4792, 201.9236],
        [202.1731, 202.6109, 202.4687, 202.3696, 202.4928, 201.6174, 202.3692,
         202.6188, 202.5188, 201.9627],
        [202.2097, 202.6437, 202.5036, 202.4039, 202.5273, 201.6520, 202.4031,
         202.6529, 202.5537, 201.9975],
        [202.3576, 202.7921, 202.6528, 202.5532, 202.6759, 201.8010, 202.5520,
         202.8016, 202.7026, 202.1463],
        [202.2741, 202

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[201.5754, 202.5465, 202.2540, 202.7857, 202.4078, 202.5525, 202.5151,
         202.1730, 202.3872, 203.0297],
        [201.8344, 202.8049, 202.5126, 203.0466, 202.6665, 202.8089, 202.7743,
         202.4323, 202.6447, 203.2911],
        [201.5193, 202.4908, 202.1977, 202.7329, 202.3535, 202.4953, 202.4604,
         202.1164, 202.3308, 202.9772],
        [201.2824, 202.2536, 201.9611, 202.4951, 202.1156, 202.2601, 202.2226,
         201.8792, 202.0942, 202.7381],
        [201.6138, 202.5844, 202.2918, 202.8268, 202.4476, 202.5892, 202.5546,
         202.2110, 202.4249, 203.0704],
        [201.6954, 202.6660, 202.3735, 202.9069, 202.5281, 202.6709, 202.6354,
         202.2931, 202.5064, 203.1510],
        [201.7445, 202.7164, 202.4228, 202.9557, 202.5782, 202.7203, 202.6855,
         202.3421, 202.5563, 203.2012],
        [201.6898, 202

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[202.6677, 201.7353, 201.7765, 202.9981, 202.0626, 202.3511, 202.1602,
         202.5337, 202.4058, 202.2363],
        [202.6521, 201.7222, 201.7637, 202.9830, 202.0491, 202.3389, 202.1459,
         202.5190, 202.3925, 202.2238],
        [202.5605, 201.6308, 201.6702, 202.8921, 201.9571, 202.2453, 202.0536,
         202.4270, 202.3005, 202.1305],
        [202.5712, 201.6417, 201.6827, 202.9026, 201.9689, 202.2580, 202.0652,
         202.4380, 202.3121, 202.1433],
        [202.7072, 201.7771, 201.8190, 203.0377, 202.1040, 202.3942, 202.2009,
         202.5741, 202.4474, 202.2790],
        [202.7683, 201.8382, 201.8786, 203.0989, 202.1637, 202.4534, 202.2613,
         202.6352, 202.5076, 202.3380],
        [202.7634, 201.8338, 201.8737, 203.0944, 202.1603, 202.4479, 202.2574,
         202.6287, 202.5034, 202.3338],
        [202.9087, 201

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[201.4426, 201.9474, 201.7608, 202.3579, 202.2849, 202.1491, 203.2987,
         202.2199, 203.1305, 202.0556],
        [201.4169, 201.9222, 201.7341, 202.3285, 202.2549, 202.1202, 203.2732,
         202.1941, 203.1031, 202.0268],
        [201.3260, 201.8305, 201.6435, 202.2409, 202.1682, 202.0306, 203.1830,
         202.1035, 203.0138, 201.9398],
        [201.3897, 201.8934, 201.7073, 202.3038, 202.2304, 202.0939, 203.2455,
         202.1671, 203.0771, 202.0016],
        [201.3355, 201.8398, 201.6528, 202.2491, 202.1748, 202.0408, 203.1910,
         202.1123, 203.0222, 201.9457],
        [201.3215, 201.8255, 201.6390, 202.2370, 202.1642, 202.0265, 203.1782,
         202.0990, 203.0096, 201.9355],
        [201.3989, 201.9027, 201.7166, 202.3131, 202.2399, 202.1030, 203.2549,
         202.1765, 203.0864, 202.0112],
        [201.3271, 201.8318, 201.6444, 202.2408, 202.1684, 202.0299, 203.1844,
         202.1050, 203.0144, 201

tensor([[201.6371, 202.3733, 201.8662, 201.9262, 202.1939, 203.2161, 202.3845,
         202.3912, 202.3064, 202.3775],
        [201.5470, 202.2845, 201.7751, 201.8360, 202.1044, 203.1252, 202.2946,
         202.3021, 202.2165, 202.2865],
        [201.6330, 202.3705, 201.8636, 201.9227, 202.1897, 203.2132, 202.3825,
         202.3880, 202.3031, 202.3743],
        [201.6342, 202.3717, 201.8620, 201.9236, 202.1915, 203.2117, 202.3813,
         202.3892, 202.3036, 202.3732],
        [201.5156, 202.2509, 201.7446, 201.8047, 202.0729, 203.0957, 202.2639,
         202.2702, 202.1860, 202.2581],
        [201.6167, 202.3524, 201.8441, 201.9058, 202.1742, 203.1949, 202.3636,
         202.3713, 202.2868, 202.3577],
        [201.6755, 202.4095, 201.9009, 201.9635, 202.2326, 203.2516, 202.4183,
         202.4279, 202.3437, 202.4138],
        [201.5681, 202.3036, 201.7947, 201.8567, 202.1257, 203.1455, 202.3136,
         202.3222, 202.2374, 202.3077],
        [201.6016, 202.3374, 201.8283, 201.8904,

tensor([[202.2750, 202.2986, 202.8184, 202.7946, 201.5756, 201.9629, 202.0363,
         201.8235, 202.6653, 202.1215],
        [202.2701, 202.2946, 202.8105, 202.7921, 201.5710, 201.9577, 202.0301,
         201.8191, 202.6620, 202.1184],
        [202.2409, 202.2643, 202.7836, 202.7604, 201.5413, 201.9284, 202.0021,
         201.7891, 202.6307, 202.0873],
        [202.2641, 202.2872, 202.8041, 202.7829, 201.5653, 201.9498, 202.0244,
         201.8121, 202.6528, 202.1116],
        [202.1810, 202.2050, 202.7228, 202.7025, 201.4811, 201.8688, 201.9415,
         201.7291, 202.5725, 202.0278],
        [202.2474, 202.2716, 202.7869, 202.7688, 201.5486, 201.9343, 202.0071,
         201.7961, 202.6385, 202.0960],
        [202.2534, 202.2761, 202.7916, 202.7722, 201.5546, 201.9378, 202.0130,
         201.8009, 202.6415, 202.1011],
        [202.3398, 202.3652, 202.8840, 202.8623, 201.6409, 202.0302, 202.1012,
         201.8903, 202.7331, 202.1880],
        [202.3456, 202.3676, 202.8850, 202.8644,

tensor([[201.8513, 201.7663, 202.2966, 202.3098, 201.8551, 202.2874, 202.2329,
         201.2962, 201.9094, 202.1006],
        [201.9765, 201.8927, 202.4223, 202.4353, 201.9803, 202.4124, 202.3570,
         201.4236, 202.0345, 202.2245],
        [202.0257, 201.9409, 202.4714, 202.4844, 202.0297, 202.4619, 202.4060,
         201.4726, 202.0836, 202.2740],
        [201.6370, 201.5520, 202.0821, 202.0954, 201.6404, 202.0733, 202.0202,
         201.0799, 201.6954, 201.8875],
        [202.1047, 202.0225, 202.5481, 202.5623, 202.1094, 202.5376, 202.4855,
         201.5495, 202.1625, 202.3531],
        [201.9740, 201.8900, 202.4182, 202.4322, 201.9784, 202.4089, 202.3550,
         201.4186, 202.0319, 202.2228],
        [202.1104, 202.0278, 202.5546, 202.5681, 202.1150, 202.5437, 202.4913,
         201.5560, 202.1683, 202.3589],
        [201.9597, 201.8763, 202.4053, 202.4182, 201.9634, 202.3948, 202.3411,
         201.4062, 202.0178, 202.2084],
        [201.8204, 201.7362, 202.2658, 202.2787,

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[203.0979, 202.1304, 202.3181, 201.9709, 201.4160, 201.7489, 202.0273,
         201.9486, 202.1240, 201.9365],
        [203.0014, 202.0351, 202.2206, 201.8728, 201.3176, 201.6531, 201.9289,
         201.8521, 202.0259, 201.8404],
        [203.3730, 202.4058, 202.5923, 202.2464, 201.6900, 202.0254, 202.3010,
         202.2236, 202.3981, 202.2113],
        [202.9636, 201.9957, 202.1824, 201.8353, 201.2802, 201.6141, 201.8911,
         201.8132, 201.9906, 201.8004],
        [202.8530, 201.8863, 202.0732, 201.7245, 201.1703, 201.5040, 201.7816,
         201.7041, 201.8786, 201.6920],
        [203.1161, 202.1500, 202.3338, 201.9872, 201.4311, 201.7683, 202.0425,
         201.9657, 202.1397, 201.9548],
        [203.0598, 202.0929, 202.2784, 201.9314, 201.3754, 201.7116, 201.9863,
         201.9104, 202.0853, 201.8974],
        [203.0212, 202.0529, 202.2414, 201.8938, 201.33

tensor([0, 0, 0, 0, 0, 9, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[201.9658, 201.8664, 201.4048, 202.1086, 202.2420, 202.0327, 201.9197,
         201.7850, 202.1006, 202.1184],
        [202.3243, 202.2237, 201.7638, 202.4655, 202.5991, 202.3898, 202.2796,
         202.1446, 202.4579, 202.4754],
        [201.9487, 201.8490, 201.3867, 202.0914, 202.2235, 202.0166, 201.9028,
         201.7673, 202.0816, 202.1004],
        [202.3050, 202.2044, 201.7445, 202.4452, 202.5790, 202.3698, 202.2602,
         202.1257, 202.4370, 202.4551],
        [201.7551, 201.6554, 201.1927, 201.8970, 202.0296, 201.8221, 201.7082,
         201.5737, 201.8868, 201.9062],
        [202.1110, 202.0104, 201.5504, 202.2510, 202.3855, 202.1747, 202.0650,
         201.9317, 202.2434, 202.2612],
        [201.9913, 201.8924, 201.4306, 202.1319, 202.2663, 202.0569, 201.9447,
         201.8119, 202.1224, 202.1417],
        [202.1538, 202.0533, 201.5923, 202.2973, 202.42

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[201.8045, 201.4191, 201.9357, 202.0466, 201.2407, 201.2055, 201.2719,
         202.1418, 201.5873, 201.8824],
        [201.7390, 201.3517, 201.8678, 201.9799, 201.1746, 201.1395, 201.2060,
         202.0752, 201.5211, 201.8170],
        [201.7790, 201.3955, 201.9098, 202.0224, 201.2182, 201.1819, 201.2467,
         202.1170, 201.5636, 201.8581],
        [201.8661, 201.4790, 201.9958, 202.1072, 201.3021, 201.2665, 201.3331,
         202.2024, 201.6480, 201.9438],
        [201.8679, 201.4818, 201.9976, 202.1087, 201.3058, 201.2673, 201.3336,
         202.2041, 201.6497, 201.9441],
        [201.8971, 201.5122, 202.0290, 202.1396, 201.3337, 201.2985, 201.3649,
         202.2346, 201.6802, 201.9752],
        [201.6366, 201.2491, 201.7649, 201.8773, 201.0722, 201.0368, 201.1031,
         201.9726, 201.4185, 201.7143],
        [201.8279, 201

       grad_fn=<CdistBackward0>)
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[201.8836, 201.6640, 201.8907, 202.4555, 201.7508, 202.5863, 202.0908,
         202.1817, 202.0930, 201.9856],
        [201.9870, 201.7666, 201.9938, 202.5590, 201.8534, 202.6887, 202.1941,
         202.2849, 202.1968, 202.0895],
        [201.9204, 201.6988, 201.9270, 202.4943, 201.7844, 202.6237, 202.1277,
         202.2183, 202.1302, 202.0224],
        [202.0045, 201.7850, 202.0119, 202.5769, 201.8711, 202.7076, 202.2116,
         202.3023, 202.2140, 202.1062],
        [202.0478, 201.8264, 202.0545, 202.6218, 201.9112, 202.7513, 202.2547,
         202.3445, 202.2573, 202.1491],
        [201.9242, 201.7041, 201.9310, 202.4957, 201.7911, 202.6259, 202.1311,
         202.2215, 202.1335, 202.0263],
        [201.9322, 201.7112, 201.9391, 202.5050, 201.7965, 202.6352, 202.1385,
         202.2268, 202.1402, 202.0325],
        [201.9836, 201

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[201.8808, 201.1323, 202.0304, 201.8304, 201.7703, 202.1434, 201.5926,
         202.0011, 201.6404, 201.4958],
        [201.9558, 201.2069, 202.1045, 201.9045, 201.8430, 202.2169, 201.6658,
         202.0762, 201.7144, 201.5698],
        [201.9161, 201.1646, 202.0652, 201.8659, 201.8041, 202.1769, 201.6272,
         202.0355, 201.6749, 201.5283],
        [201.9070, 201.1559, 202.0544, 201.8537, 201.7929, 202.1671, 201.6160,
         202.0268, 201.6642, 201.5196],
        [202.2025, 201.4546, 202.3512, 202.1523, 202.0896, 202.4637, 201.9130,
         202.3230, 201.9619, 201.8170],
        [201.6328, 200.8808, 201.7816, 201.5802, 201.5216, 201.8944, 201.3438,
         201.7521, 201.3907, 201.2456],
        [202.0446, 201.2969, 202.1928, 201.9925, 201.9321, 202.3063, 201.7549,
         202.1651, 201.8034, 201.6599],
        [201.9012, 201

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[201.7238, 201.5921, 201.2189, 200.8719, 201.9092, 201.9986, 201.4722,
         201.8975, 202.5511, 201.6497],
        [201.9649, 201.8336, 201.4597, 201.1131, 202.1504, 202.2391, 201.7135,
         202.1390, 202.7925, 201.8903],
        [202.0131, 201.8819, 201.5077, 201.1613, 202.1984, 202.2875, 201.7614,
         202.1869, 202.8405, 201.9364],
        [201.9557, 201.8258, 201.4520, 201.1044, 202.1428, 202.2304, 201.7044,
         202.1306, 202.7843, 201.8808],
        [201.8574, 201.7282, 201.3541, 201.0067, 202.0450, 202.1328, 201.6067,
         202.0338, 202.6867, 201.7841],
        [201.9132, 201.7810, 201.4068, 201.0612, 202.0974, 202.1875, 201.6617,
         202.0868, 202.7400, 201.8383],
        [201.6148, 201.4838, 201.1108, 200.7634, 201.8011, 201.8902, 201.3633,
         201.7892, 202.4426, 201.5412],
        [201.6860, 201

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[201.7742, 201.7885, 201.6497, 201.3822, 201.8405, 201.7831, 200.9007,
         201.6756, 201.4977, 201.4614],
        [202.0342, 202.0471, 201.9089, 201.6415, 202.1004, 202.0403, 201.1586,
         201.9343, 201.7570, 201.7223],
        [201.7471, 201.7588, 201.6211, 201.3536, 201.8122, 201.7524, 200.8708,
         201.6471, 201.4693, 201.4330],
        [201.9153, 201.9275, 201.7905, 201.5228, 201.9816, 201.9210, 201.0399,
         201.8165, 201.6391, 201.6022],
        [201.8843, 201.8961, 201.7589, 201.4913, 201.9501, 201.8891, 201.0081,
         201.7846, 201.6072, 201.5712],
        [201.8019, 201.8157, 201.6772, 201.4097, 201.8681, 201.8101, 200.9279,
         201.7031, 201.5253, 201.4890],
        [201.9445, 201.9561, 201.8184, 201.5510, 202.0100, 201.9487, 201.0676,
         201.8438, 201.6665, 201.6318],
        [202.1325, 202

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[201.6982, 201.5655, 202.0054, 201.7298, 201.8944, 201.3712, 201.8105,
         200.9719, 201.7052, 201.6440],
        [201.8622, 201.7256, 202.1693, 201.8925, 202.0588, 201.5340, 201.9753,
         201.1349, 201.8689, 201.8061],
        [201.8497, 201.7157, 202.1584, 201.8810, 202.0469, 201.5245, 201.9620,
         201.1236, 201.8566, 201.7961],
        [202.0591, 201.9237, 202.3658, 202.0897, 202.2548, 201.7318, 202.1705,
         201.3318, 202.0663, 202.0034],
        [201.7507, 201.6151, 202.0564, 201.7813, 201.9463, 201.4211, 201.8631,
         201.0228, 201.7576, 201.6938],
        [201.8147, 201.6810, 202.1237, 201.8462, 202.0120, 201.4899, 201.9269,
         201.0885, 201.8214, 201.7613],
        [201.8182, 201.6850, 202.1251, 201.8502, 202.0137, 201.4911, 201.9281,
         201.0900, 201.8248, 201.7625],
        [201.8172, 201

       grad_fn=<CdistBackward0>)
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[201.3748, 201.8583, 202.0372, 202.7055, 201.5109, 202.0534, 202.0977,
         201.8792, 201.9594, 201.9174],
        [201.1743, 201.6571, 201.8338, 202.5050, 201.3110, 201.8530, 201.8944,
         201.6797, 201.7588, 201.7138],
        [201.2052, 201.6882, 201.8670, 202.5355, 201.3410, 201.8831, 201.9278,
         201.7097, 201.7896, 201.7474],
        [201.3655, 201.8484, 202.0270, 202.6964, 201.5021, 202.0440, 202.0869,
         201.8715, 201.9503, 201.9069],
        [201.2900, 201.7736, 201.9506, 202.6211, 201.4268, 201.9700, 202.0123,
         201.7941, 201.8742, 201.8313],
        [201.0474, 201.5297, 201.7081, 202.3775, 201.1833, 201.7254, 201.7691,
         201.5523, 201.6317, 201.5886],
        [201.5488, 202.0329, 202.2105, 202.8799, 201.6857, 202.2276, 202.2706,
         202.0536, 202.1338, 202.0903],
        [201.0862, 201

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[201.1966, 201.2937, 201.6229, 201.4837, 200.9121, 201.4274, 201.3172,
         201.6219, 201.2435, 201.9983],
        [201.5571, 201.6508, 201.9837, 201.8464, 201.2742, 201.7872, 201.6780,
         201.9823, 201.6043, 202.3599],
        [201.4444, 201.5403, 201.8712, 201.7328, 201.1613, 201.6758, 201.5657,
         201.8704, 201.4921, 202.2473],
        [201.4067, 201.5032, 201.8321, 201.6931, 201.1227, 201.6364, 201.5265,
         201.8309, 201.4533, 202.2069],
        [201.4214, 201.5190, 201.8477, 201.7087, 201.1379, 201.6530, 201.5422,
         201.8467, 201.4691, 202.2232],
        [201.3948, 201.4907, 201.8194, 201.6799, 201.1105, 201.6232, 201.5141,
         201.8191, 201.4408, 202.1938],
        [201.5466, 201.6430, 201.9721, 201.8332, 201.2633, 201.7768, 201.6668,
         201.9716, 201.5938, 202.3473],
        [201.2670, 201.3614, 201.6929, 201.5547, 200.9827, 201.4959, 201.3869,
         201.6915, 201.3130, 202

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[200.9065, 201.9448, 201.7674, 201.7432, 201.6302, 201.4693, 201.6553,
         202.4366, 201.6007, 201.9136],
        [201.0476, 202.0862, 201.9095, 201.8851, 201.7717, 201.6111, 201.7965,
         202.5771, 201.7435, 202.0547],
        [201.1227, 202.1603, 201.9820, 201.9582, 201.8462, 201.6853, 201.8709,
         202.6512, 201.8161, 202.1292],
        [200.7997, 201.8367, 201.6579, 201.6334, 201.5217, 201.3602, 201.5475,
         202.3283, 201.4914, 201.8059],
        [200.7281, 201.7655, 201.5891, 201.5641, 201.4505, 201.2894, 201.4770,
         202.2577, 201.4239, 201.7329],
        [200.8804, 201.9180, 201.7391, 201.7150, 201.6032, 201.4418, 201.6286,
         202.4099, 201.5723, 201.8873],
        [200.7098, 201.7472, 201.5695, 201.5445, 201.4319, 201.2706, 201.4581,
         202.2389, 201.4034, 201.7156],
        [200.8031, 201.8400, 201.6629, 201.6375, 201.5248, 201.3639, 201.5511,
         202.3306, 201.4982, 201

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[201.5740, 201.8750, 201.0524, 201.4229, 200.8414, 201.8101, 201.5984,
         201.6261, 201.6267, 201.3384],
        [201.6891, 201.9907, 201.1690, 201.5386, 200.9571, 201.9240, 201.7140,
         201.7422, 201.7410, 201.4551],
        [201.4956, 201.7960, 200.9741, 201.3436, 200.7632, 201.7325, 201.5199,
         201.5470, 201.5473, 201.2598],
        [201.7362, 202.0377, 201.2159, 201.5855, 201.0040, 201.9707, 201.7611,
         201.7892, 201.7881, 201.5018],
        [201.6569, 201.9534, 201.1342, 201.5000, 200.9219, 201.8909, 201.6810,
         201.7046, 201.7046, 201.4172],
        [201.7760, 202.0737, 201.2543, 201.6206, 201.0415, 202.0090, 201.8004,
         201.8252, 201.8245, 201.5378],
        [201.8987, 202.1980, 201.3769, 201.7453, 201.1652, 202.1320, 201.9226,
         201.9496, 201.9494, 201.6619],
        [201.7736, 202

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[200.8957, 201.7679, 201.6871, 201.5292, 202.0599, 202.2846, 201.7627,
         201.8795, 200.9239, 201.3657],
        [200.8615, 201.7337, 201.6521, 201.4957, 202.0279, 202.2497, 201.7285,
         201.8471, 200.8893, 201.3295],
        [200.9655, 201.8378, 201.7563, 201.5989, 202.1297, 202.3532, 201.8312,
         201.9485, 200.9926, 201.4357],
        [200.8373, 201.7123, 201.6294, 201.4726, 202.0034, 202.2273, 201.7032,
         201.8221, 200.8661, 201.3097],
        [200.6316, 201.5075, 201.4252, 201.2678, 201.7988, 202.0241, 201.4996,
         201.6185, 200.6627, 201.1044],
        [200.8493, 201.7229, 201.6406, 201.4839, 202.0152, 202.2379, 201.7153,
         201.8339, 200.8771, 201.3198],
        [200.8389, 201.7117, 201.6309, 201.4735, 202.0044, 202.2299, 201.7077,
         201.8245, 200.8692, 201.3088],
        [200.9086, 201

tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[201.2983, 201.5149, 201.6448, 201.6304, 200.9231, 200.8801, 202.1570,
         200.7747, 201.3262, 201.6124],
        [201.4122, 201.6291, 201.7581, 201.7448, 201.0365, 200.9946, 202.2695,
         200.8892, 201.4402, 201.7254],
        [201.6057, 201.8234, 201.9530, 201.9371, 201.2308, 201.1917, 202.4633,
         201.0826, 201.6339, 201.9208],
        [201.1931, 201.4104, 201.5399, 201.5252, 200.8174, 200.7751, 202.0517,
         200.6699, 201.2212, 201.5076],
        [201.3484, 201.5659, 201.6950, 201.6803, 200.9725, 200.9327, 202.2048,
         200.8250, 201.3764, 201.6619],
        [201.4405, 201.6570, 201.7875, 201.7719, 201.0657, 201.0240, 202.2990,
         200.9165, 201.4684, 201.7551],
        [201.3378, 201.5543, 201.6847, 201.6696, 200.9631, 200.9198, 202.1973,
         200.8141, 201.3658, 201.6526],
        [201.1760, 201.3933, 201.5225, 201.5085, 200.80

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[201.6243, 201.3842, 201.4340, 201.4043, 200.9991, 202.0851, 201.8197,
         201.6150, 200.7881, 201.7560],
        [201.5638, 201.3226, 201.3733, 201.3430, 200.9393, 202.0239, 201.7584,
         201.5545, 200.7272, 201.6955],
        [201.4865, 201.2469, 201.2963, 201.2664, 200.8617, 201.9469, 201.6826,
         201.4780, 200.6499, 201.6182],
        [201.6004, 201.3607, 201.4102, 201.3802, 200.9760, 202.0600, 201.7972,
         201.5920, 200.7640, 201.7309],
        [201.4974, 201.2574, 201.3076, 201.2775, 200.8721, 201.9581, 201.6934,
         201.4881, 200.6614, 201.6297],
        [201.6608, 201.4185, 201.4707, 201.4397, 201.0365, 202.1194, 201.8560,
         201.6511, 200.8251, 201.7912],
        [201.6914, 201.4510, 201.5012, 201.4710, 201.0670, 202.1503, 201.8880,
         201.6826, 200.8553, 201.8211],
        [201.6950, 201.4539, 201.5043, 201.4740, 201.0716, 202.1537, 201.8907,
         201.6866, 200.8582, 201

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[201.7533, 200.9456, 201.4803, 201.3123, 201.6682, 201.3275, 201.1918,
         200.7985, 201.5599, 201.1325],
        [201.6113, 200.8022, 201.3374, 201.1707, 201.5264, 201.1848, 201.0491,
         200.6574, 201.4177, 200.9893],
        [201.5645, 200.7575, 201.2909, 201.1235, 201.4797, 201.1375, 201.0027,
         200.6097, 201.3720, 200.9432],
        [201.8526, 201.0438, 201.5803, 201.4124, 201.7675, 201.4254, 201.2911,
         200.8969, 201.6591, 201.2325],
        [201.9365, 201.1313, 201.6638, 201.4972, 201.8543, 201.5119, 201.3788,
         200.9841, 201.7462, 201.3177],
        [201.8816, 201.0769, 201.6087, 201.4411, 201.7986, 201.4580, 201.3232,
         200.9289, 201.6905, 201.2622],
        [201.7702, 200.9613, 201.4968, 201.3306, 201.6863, 201.3443, 201.2097,
         200.8171, 201.5775, 201.1493],
        [201.5098, 200

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[201.6224, 200.3233, 200.6357, 200.4888, 200.8211, 201.3436, 201.0179,
         201.3908, 201.4544, 201.1993],
        [201.9402, 200.6421, 200.9562, 200.8089, 201.1409, 201.6628, 201.3393,
         201.7093, 201.7739, 201.5174],
        [201.8247, 200.5260, 200.8394, 200.6920, 201.0245, 201.5466, 201.2213,
         201.5932, 201.6573, 201.4018],
        [201.6541, 200.3550, 200.6672, 200.5201, 200.8523, 201.3749, 201.0498,
         201.4217, 201.4858, 201.2308],
        [201.9170, 200.6169, 200.9299, 200.7803, 201.1161, 201.6352, 201.3108,
         201.6804, 201.7478, 201.4940],
        [201.6712, 200.3720, 200.6844, 200.5370, 200.8698, 201.3916, 201.0670,
         201.4385, 201.5032, 201.2481],
        [201.8200, 200.5216, 200.8351, 200.6884, 201.0196, 201.5431, 201.2172,
         201.5899, 201.6528, 201.3969],
        [202.0330, 200

tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[201.6841, 201.4911, 201.7606, 202.0889, 201.3898, 201.4697, 201.6958,
         201.5749, 201.2568, 201.9589],
        [201.7384, 201.5472, 201.8144, 202.1430, 201.4461, 201.5250, 201.7512,
         201.6289, 201.3131, 202.0139],
        [201.7165, 201.5243, 201.7954, 202.1232, 201.4239, 201.5040, 201.7291,
         201.6067, 201.2885, 201.9940],
        [201.6240, 201.4317, 201.6991, 202.0279, 201.3301, 201.4096, 201.6360,
         201.5141, 201.1976, 201.8984],
        [201.5845, 201.3914, 201.6597, 201.9885, 201.2898, 201.3697, 201.5958,
         201.4751, 201.1576, 201.8585],
        [201.6746, 201.4819, 201.7534, 202.0814, 201.3815, 201.4618, 201.6867,
         201.5650, 201.2465, 201.9517],
        [201.7368, 201.5433, 201.8117, 202.1396, 201.4411, 201.5211, 201.7485,
         201.6256, 201.3075, 202.0106],
        [201.6372, 201.4432, 201.7115, 202.0396, 201.34

tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[201.7259, 201.5322, 201.7774, 202.0551, 202.1844, 201.4059, 201.7225,
         201.6614, 201.1591, 201.5959],
        [201.4657, 201.2702, 201.5205, 201.7944, 201.9231, 201.1454, 201.4628,
         201.4023, 200.8998, 201.3346],
        [201.5904, 201.3955, 201.6435, 201.9183, 202.0482, 201.2704, 201.5859,
         201.5247, 201.0235, 201.4598],
        [201.6112, 201.4163, 201.6642, 201.9391, 202.0691, 201.2913, 201.6067,
         201.5455, 201.0444, 201.4807],
        [201.5296, 201.3360, 201.5833, 201.8578, 201.9872, 201.2098, 201.5256,
         201.4650, 200.9637, 201.3994],
        [201.4381, 201.2430, 201.4901, 201.7666, 201.8956, 201.1169, 201.4345,
         201.3730, 200.8696, 201.3063],
        [201.8636, 201.6693, 201.9164, 202.1934, 202.3224, 201.5442, 201.8611,
         201.8006, 201.2983, 201.7342],
        [201.5213, 201.3256, 201.5737, 201.8504, 201.97

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[201.4386, 201.5633, 201.2753, 201.3601, 201.5335, 201.6958, 201.6433,
         201.8953, 201.6206, 200.5033],
        [201.6955, 201.8198, 201.5323, 201.6146, 201.7888, 201.9538, 201.8985,
         202.1520, 201.8752, 200.7587],
        [201.3815, 201.5063, 201.2180, 201.3018, 201.4758, 201.6388, 201.5855,
         201.8382, 201.5632, 200.4458],
        [201.6808, 201.8054, 201.5174, 201.5994, 201.7739, 201.9390, 201.8836,
         202.1374, 201.8605, 200.7439],
        [201.3698, 201.4932, 201.2052, 201.2874, 201.4618, 201.6269, 201.5705,
         201.8257, 201.5473, 200.4305],
        [201.4562, 201.5797, 201.2919, 201.3730, 201.5480, 201.7136, 201.6573,
         201.9120, 201.6346, 200.5177],
        [201.0871, 201.2100, 200.9224, 201.0036, 201.1788, 201.3428, 201.2886,
         201.5421, 201.2673, 200.1492],
        [201.4428, 201

tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[200.7236, 201.1302, 201.2827, 201.0485, 201.0098, 200.5890, 201.3217,
         201.1675, 201.3557, 201.3079],
        [200.8019, 201.2093, 201.3609, 201.1275, 201.0886, 200.6681, 201.3994,
         201.2462, 201.4342, 201.3857],
        [200.6391, 201.0468, 201.1987, 200.9642, 200.9256, 200.5036, 201.2375,
         201.0842, 201.2729, 201.2232],
        [200.4064, 200.8124, 200.9665, 200.7307, 200.6916, 200.2687, 201.0069,
         200.8509, 201.0403, 200.9921],
        [200.7207, 201.1274, 201.2802, 201.0448, 201.0071, 200.5845, 201.3185,
         201.1654, 201.3542, 201.3043],
        [200.6902, 201.0956, 201.2493, 201.0135, 200.9759, 200.5542, 201.2882,
         201.1334, 201.3223, 201.2746],
        [200.5775, 200.9870, 201.1370, 200.9045, 200.8646, 200.4438, 201.1766,
         201.0232, 201.2116, 201.1622],
        [200.7566, 201.1611, 201.3163, 201.0803, 201.04

       grad_fn=<CdistBackward0>)
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[200.6916, 200.7733, 200.7022, 200.8098, 200.7375, 200.9398, 200.7831,
         200.8649, 200.0970, 201.1042],
        [201.0883, 201.1701, 201.0963, 201.2054, 201.1333, 201.3378, 201.1784,
         201.2615, 200.4933, 201.5010],
        [200.9921, 201.0738, 201.0011, 201.1092, 201.0374, 201.2414, 201.0823,
         201.1654, 200.3970, 201.4045],
        [200.9503, 201.0323, 200.9586, 201.0677, 200.9954, 201.1996, 201.0406,
         201.1234, 200.3554, 201.3632],
        [200.9910, 201.0726, 201.0024, 201.1105, 201.0381, 201.2387, 201.0832,
         201.1646, 200.3963, 201.4029],
        [200.8003, 200.8814, 200.8107, 200.9178, 200.8459, 201.0486, 200.8914,
         200.9737, 200.2054, 201.2126],
        [200.8673, 200.9490, 200.8753, 200.9845, 200.9120, 201.1163, 200.9577,
         201.0405, 200.2725, 201.2805],
        [200.9177, 200

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[200.6691, 201.1041, 200.9454, 201.0073, 201.1344, 200.6658, 200.8471,
         200.7607, 200.3857, 201.7394],
        [200.9120, 201.3476, 201.1879, 201.2510, 201.3788, 200.9093, 201.0914,
         201.0044, 200.6284, 201.9830],
        [200.8429, 201.2757, 201.1150, 201.1791, 201.3071, 200.8374, 201.0196,
         200.9329, 200.5567, 201.9119],
        [200.9500, 201.3850, 201.2249, 201.2885, 201.4166, 200.9467, 201.1291,
         201.0421, 200.6657, 202.0207],
        [201.2407, 201.6744, 201.5139, 201.5785, 201.7052, 201.2365, 201.4183,
         201.3319, 200.9564, 202.3106],
        [201.0419, 201.4748, 201.3144, 201.3786, 201.5054, 201.0367, 201.2182,
         201.1321, 200.7569, 202.1110],
        [200.9275, 201.3615, 201.2016, 201.2648, 201.3918, 200.9228, 201.1041,
         201.0181, 200.6434, 201.9971],
        [200.8448, 201.2779, 201.1176, 201.1816, 201.30

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[201.3050, 201.1648, 201.3156, 200.7825, 200.8807, 200.9758, 201.1906,
         200.3863, 201.0299, 200.4202],
        [201.2660, 201.1252, 201.2766, 200.7452, 200.8431, 200.9389, 201.1511,
         200.3473, 200.9912, 200.3831],
        [201.3006, 201.1604, 201.3114, 200.7777, 200.8762, 200.9719, 201.1866,
         200.3821, 201.0258, 200.4154],
        [201.3707, 201.2307, 201.3806, 200.8512, 200.9487, 201.0418, 201.2548,
         200.4520, 201.0956, 200.4894],
        [201.0572, 200.9162, 201.0680, 200.5347, 200.6333, 200.7298, 200.9429,
         200.1381, 200.7824, 200.1719],
        [201.4173, 201.2782, 201.4272, 200.8952, 200.9931, 201.0862, 201.3023,
         200.4982, 201.1417, 200.5334],
        [201.3452, 201.2050, 201.3554, 200.8230, 200.9216, 201.0155, 201.2305,
         200.4272, 201.0704, 200.4604],
        [201.3699, 201.2292, 201.3811, 200.8472, 200.94

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[202.1397, 201.5054, 200.9867, 201.7032, 201.0005, 201.0323, 201.4265,
         200.6214, 200.4451, 201.1893],
        [202.1723, 201.5380, 201.0198, 201.7361, 201.0328, 201.0651, 201.4592,
         200.6539, 200.4784, 201.2225],
        [202.2389, 201.6012, 201.0866, 201.7997, 201.0986, 201.1302, 201.5260,
         200.7199, 200.5465, 201.2876],
        [202.1614, 201.5243, 201.0075, 201.7220, 201.0222, 201.0525, 201.4478,
         200.6428, 200.4669, 201.2086],
        [202.2017, 201.5677, 201.0504, 201.7662, 201.0615, 201.0948, 201.4893,
         200.6831, 200.5088, 201.2533],
        [202.0055, 201.3687, 200.8517, 201.5676, 200.8678, 200.8968, 201.2916,
         200.4879, 200.3105, 201.0535],
        [202.1405, 201.5033, 200.9868, 201.7012, 201.0012, 201.0315, 201.4270,
         200.6219, 200.4460, 201.1878],
        [201.9973, 201

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[201.0384, 201.1899, 200.8023, 201.0307, 200.8758, 200.7567, 200.3252,
         201.0354, 201.2238, 201.7007],
        [201.0736, 201.2246, 200.8368, 201.0651, 200.9108, 200.7897, 200.3611,
         201.0673, 201.2598, 201.7347],
        [201.1302, 201.2811, 200.8933, 201.1215, 200.9672, 200.8462, 200.4174,
         201.1239, 201.3158, 201.7914],
        [201.0035, 201.1528, 200.7649, 200.9922, 200.8392, 200.7207, 200.2887,
         200.9997, 201.1870, 201.6667],
        [201.1784, 201.3298, 200.9426, 201.1703, 201.0160, 200.8964, 200.4656,
         201.1746, 201.3637, 201.8405],
        [200.8295, 200.9813, 200.5928, 200.8227, 200.6670, 200.5470, 200.1168,
         200.8253, 201.0166, 201.4908],
        [200.8832, 201.0332, 200.6449, 200.8732, 200.7196, 200.6000, 200.1696,
         200.8785, 201.0688, 201.5455],
        [201.0307, 201.1803, 200.7925, 201.0199, 200.86

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[201.2251, 201.0593, 200.9348, 201.7356, 201.1160, 201.0208, 201.0373,
         200.5094, 201.0875, 201.0341],
        [201.1429, 200.9775, 200.8539, 201.6528, 201.0357, 200.9380, 200.9540,
         200.4272, 201.0061, 200.9536],
        [201.1663, 200.9996, 200.8755, 201.6747, 201.0572, 200.9621, 200.9772,
         200.4490, 201.0281, 200.9750],
        [201.1119, 200.9463, 200.8230, 201.6210, 201.0050, 200.9070, 200.9220,
         200.3958, 200.9751, 200.9219],
        [201.3231, 201.1573, 201.0331, 201.8329, 201.2142, 201.1182, 201.1350,
         200.6070, 201.1855, 201.1329],
        [201.1902, 201.0234, 200.8993, 201.6986, 201.0809, 200.9861, 201.0012,
         200.4729, 201.0519, 200.9982],
        [201.0135, 200.8491, 200.7253, 201.5254, 200.9071, 200.8088, 200.8259,
         200.2993, 200.8776, 200.8250],
        [200.8465, 200.6810, 200.5570, 201.3572, 200.73

tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[200.9692, 200.6428, 200.5551, 200.8554, 200.0518, 201.4706, 200.6662,
         200.6374, 200.9496, 201.7787],
        [201.3270, 201.0019, 200.9158, 201.2124, 200.4102, 201.8279, 201.0236,
         200.9953, 201.3053, 202.1361],
        [201.1980, 200.8726, 200.7858, 201.0842, 200.2815, 201.6987, 200.8954,
         200.8670, 201.1795, 202.0091],
        [201.1887, 200.8629, 200.7746, 201.0739, 200.2713, 201.6899, 200.8850,
         200.8564, 201.1673, 201.9964],
        [201.2131, 200.8875, 200.8016, 201.0995, 200.2958, 201.7152, 200.9098,
         200.8811, 201.1918, 202.0212],
        [201.1356, 200.8100, 200.7237, 201.0216, 200.2189, 201.6364, 200.8328,
         200.8045, 201.1158, 201.9465],
        [201.2488, 200.9235, 200.8367, 201.1344, 200.3319, 201.7498, 200.9455,
         200.9172, 201.2282, 202.0583],
        [201.0744, 200.7483, 200.6602, 200.9598, 200.15

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[201.1552, 201.3365, 201.1176, 200.9897, 200.4122, 200.9883, 201.5681,
         201.1993, 201.1531, 201.8148],
        [201.1563, 201.3383, 201.1182, 200.9913, 200.4132, 200.9896, 201.5693,
         201.2009, 201.1534, 201.8147],
        [200.8322, 201.0153, 200.7951, 200.6670, 200.0877, 200.6649, 201.2464,
         200.8770, 200.8288, 201.4902],
        [201.0730, 201.2551, 201.0355, 200.9079, 200.3295, 200.9059, 201.4870,
         201.1173, 201.0698, 201.7315],
        [201.0068, 201.1882, 200.9694, 200.8409, 200.2636, 200.8395, 201.4196,
         201.0514, 201.0042, 201.6658],
        [200.9825, 201.1633, 200.9459, 200.8162, 200.2389, 200.8146, 201.3964,
         201.0262, 200.9796, 201.6420],
        [200.7454, 200.9293, 200.7081, 200.5804, 200.0002, 200.5783, 201.1593,
         200.7903, 200.7425, 201.4036],
        [201.0760, 201

tensor([5, 6, 6, 5, 6, 5, 6, 5, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[201.0440, 200.7942, 201.0325, 200.8574, 201.4171, 201.6879, 200.8674,
         201.7802, 200.8502, 200.7186],
        [201.0212, 200.7730, 201.0121, 200.8353, 201.3929, 201.6657, 200.8465,
         201.7579, 200.8279, 200.6945],
        [201.0447, 200.7960, 201.0333, 200.8555, 201.4163, 201.6885, 200.8718,
         201.7797, 200.8486, 200.7192],
        [201.0296, 200.7799, 201.0172, 200.8411, 201.4022, 201.6730, 200.8545,
         201.7647, 200.8342, 200.7048],
        [201.1136, 200.8658, 201.1044, 200.9266, 201.4853, 201.7582, 200.9395,
         201.8504, 200.9189, 200.7868],
        [201.1838, 200.9355, 201.1730, 200.9961, 201.5555, 201.8279, 201.0101,
         201.9193, 200.9887, 200.8583],
        [200.9044, 200.6555, 200.8945, 200.7175, 201.2762, 201.5486, 200.7294,
         201.6408, 200.7104, 200.5780],
        [201.2643, 201.0160, 201.2549, 201.0800, 201.63

tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[200.7754, 200.7463, 200.6085, 200.5021, 200.7163, 200.5995, 200.7721,
         201.1939, 201.0093, 200.6840],
        [200.6624, 200.6327, 200.4942, 200.3868, 200.6027, 200.4850, 200.6577,
         201.0812, 200.8953, 200.5697],
        [200.8171, 200.7881, 200.6501, 200.5442, 200.7578, 200.6417, 200.8142,
         201.2354, 201.0515, 200.7262],
        [200.8318, 200.8026, 200.6635, 200.5590, 200.7722, 200.6566, 200.8289,
         201.2498, 201.0671, 200.7417],
        [200.7759, 200.7465, 200.6077, 200.5022, 200.7164, 200.5999, 200.7723,
         201.1943, 201.0103, 200.6848],
        [200.5186, 200.4896, 200.3516, 200.2427, 200.4587, 200.3415, 200.5141,
         200.9370, 200.7516, 200.4255],
        [200.7436, 200.7149, 200.5782, 200.4692, 200.6836, 200.5675, 200.7403,
         201.1618, 200.9765, 200.6512],
        [200.8494, 200.8191, 200.6808, 200.5730, 200.78

tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[200.7931, 200.6115, 200.0843, 201.0464, 200.9365, 200.4176, 200.5866,
         200.9529, 200.8503, 200.7969],
        [200.8108, 200.6313, 200.1057, 201.0655, 200.9579, 200.4361, 200.6062,
         200.9731, 200.8717, 200.8163],
        [200.6450, 200.4642, 199.9377, 200.8988, 200.7901, 200.2691, 200.4389,
         200.8054, 200.7031, 200.6496],
        [200.9421, 200.7595, 200.2324, 201.1949, 201.0849, 200.5670, 200.7350,
         201.1015, 200.9990, 200.9449],
        [200.6622, 200.4813, 199.9545, 200.9160, 200.8066, 200.2864, 200.4561,
         200.8226, 200.7202, 200.6665],
        [200.9077, 200.7247, 200.1974, 201.1603, 201.0498, 200.5324, 200.7003,
         201.0667, 200.9641, 200.9101],
        [201.0033, 200.8209, 200.2940, 201.2562, 201.1464, 200.6284, 200.7966,
         201.1629, 201.0606, 201.0063],
        [201.0721, 200.8929, 200.3660, 201.3269, 201.21

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[200.7375, 200.8914, 201.0671, 201.0672, 200.1863, 200.8175, 201.4400,
         201.7102, 200.7720, 200.8232],
        [200.4901, 200.6406, 200.8159, 200.8154, 199.9345, 200.5670, 201.1870,
         201.4590, 200.5198, 200.5709],
        [200.4490, 200.6027, 200.7772, 200.7773, 199.8956, 200.5282, 201.1492,
         201.4203, 200.4810, 200.5327],
        [200.6274, 200.7785, 200.9555, 200.9550, 200.0734, 200.7056, 201.3278,
         201.5983, 200.6598, 200.7110],
        [200.6630, 200.8147, 200.9888, 200.9895, 200.1102, 200.7405, 201.3614,
         201.6328, 200.6942, 200.7444],
        [200.6082, 200.7577, 200.9330, 200.9322, 200.0525, 200.6845, 201.3036,
         201.5762, 200.6373, 200.6878],
        [200.6984, 200.8506, 201.0260, 201.0246, 200.1447, 200.7773, 201.3960,
         201.6686, 200.7298, 200.7810],
        [200.7673, 200.9212, 201.0949, 201.0947, 200.21

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[200.8603, 201.1211, 200.7978, 200.5963, 200.8332, 200.9916, 200.7025,
         200.8525, 200.4058, 201.3191],
        [200.9091, 201.1701, 200.8463, 200.6451, 200.8818, 201.0403, 200.7515,
         200.9012, 200.4547, 201.3677],
        [200.7509, 201.0106, 200.6880, 200.4881, 200.7234, 200.8817, 200.5931,
         200.7444, 200.2971, 201.2089],
        [200.9888, 201.2481, 200.9253, 200.7252, 200.9605, 201.1181, 200.8304,
         200.9816, 200.5347, 201.4459],
        [200.8673, 201.1282, 200.8045, 200.6027, 200.8399, 200.9988, 200.7099,
         200.8591, 200.4129, 201.3257],
        [201.0119, 201.2712, 200.9496, 200.7472, 200.9839, 201.1415, 200.8528,
         201.0042, 200.5573, 201.4692],
        [200.8365, 201.0951, 200.7744, 200.5730, 200.8088, 200.9663, 200.6774,
         200.8300, 200.3824, 201.2938],
        [200.9307, 201

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[200.8773, 200.0359, 200.6139, 200.8070, 201.0608, 200.7153, 200.8017,
         200.7385, 201.7776, 200.4269],
        [200.6818, 199.8412, 200.4187, 200.6132, 200.8663, 200.5179, 200.6061,
         200.5439, 201.5824, 200.2305],
        [200.7724, 199.9320, 200.5094, 200.7032, 200.9571, 200.6081, 200.6973,
         200.6355, 201.6729, 200.3224],
        [200.9445, 200.1022, 200.6806, 200.8724, 201.1266, 200.7807, 200.8690,
         200.8071, 201.8432, 200.4948],
        [200.9590, 200.1188, 200.6958, 200.8902, 201.1446, 200.7937, 200.8838,
         200.8210, 201.8599, 200.5092],
        [200.7386, 199.8961, 200.4744, 200.6675, 200.9208, 200.5730, 200.6624,
         200.6011, 201.6369, 200.2869],
        [200.9098, 200.0700, 200.6468, 200.8414, 201.0957, 200.7444, 200.8350,
         200.7725, 201.8109, 200.4603],
        [200.8012, 199

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[200.7794, 200.9029, 200.5979, 201.4124, 200.5908, 200.8921, 200.5644,
         200.9327, 200.6440, 200.7163],
        [200.7353, 200.8588, 200.5532, 201.3668, 200.5461, 200.8475, 200.5207,
         200.8884, 200.5986, 200.6706],
        [200.6539, 200.7777, 200.4717, 201.2853, 200.4647, 200.7663, 200.4413,
         200.8062, 200.5181, 200.5906],
        [200.7045, 200.8286, 200.5234, 201.3401, 200.5163, 200.8183, 200.4906,
         200.8581, 200.5708, 200.6454],
        [200.8398, 200.9638, 200.6588, 201.4726, 200.6513, 200.9526, 200.6273,
         200.9924, 200.7051, 200.7777],
        [200.7393, 200.8629, 200.5573, 201.3716, 200.5503, 200.8518, 200.5246,
         200.8925, 200.6030, 200.6754],
        [200.8508, 200.9743, 200.6695, 201.4852, 200.6621, 200.9639, 200.6348,
         201.0052, 200.7148, 200.7883],
        [201.0612, 201.1852, 200.8812, 201.6954, 200.8729, 201.1741, 200.8479,
         201.2144, 200.9267, 200

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[200.3404, 200.7396, 200.7673, 200.8433, 200.8702, 200.5496, 200.6367,
         200.3605, 201.0549, 201.6853],
        [200.2367, 200.6351, 200.6630, 200.7385, 200.7666, 200.4453, 200.5333,
         200.2579, 200.9486, 201.5828],
        [200.3155, 200.7150, 200.7425, 200.8187, 200.8453, 200.5248, 200.6118,
         200.3357, 201.0301, 201.6605],
        [200.4027, 200.8036, 200.8300, 200.9061, 200.9316, 200.6118, 200.6980,
         200.4239, 201.1175, 201.7477],
        [200.2756, 200.6759, 200.7017, 200.7777, 200.8051, 200.4843, 200.5714,
         200.2982, 200.9881, 201.6221],
        [200.2021, 200.6017, 200.6285, 200.7045, 200.7320, 200.4110, 200.4986,
         200.2236, 200.9151, 201.5481],
        [200.5560, 200.9554, 200.9825, 201.0580, 201.0851, 200.7647, 200.8512,
         200.5777, 201.2693, 201.9013],
        [200.3410, 200

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[200.5791, 201.2394, 200.7147, 200.0181, 200.3706, 200.7082, 200.7204,
         200.3004, 200.8290, 200.5262],
        [200.5854, 201.2466, 200.7202, 200.0239, 200.3765, 200.7152, 200.7256,
         200.3063, 200.8349, 200.5298],
        [200.8786, 201.5384, 201.0138, 200.3183, 200.6706, 201.0055, 201.0201,
         200.5992, 201.1279, 200.8270],
        [200.5945, 201.2539, 200.7304, 200.0331, 200.3860, 200.7221, 200.7361,
         200.3149, 200.8439, 200.5427],
        [200.7807, 201.4421, 200.9152, 200.2200, 200.5722, 200.9100, 200.9211,
         200.5018, 201.0302, 200.7258],
        [200.7232, 201.3824, 200.8588, 200.1617, 200.5148, 200.8496, 200.8644,
         200.4429, 200.9720, 200.6715],
        [200.4411, 201.1014, 200.5765, 199.8785, 200.2318, 200.5701, 200.5815,
         200.1611, 200.6902, 200.3861],
        [200.3704, 201.0298, 200.5069, 199.8082, 200.1615, 200.4992, 200.5121,
         200.0910, 200.6201, 200

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[201.4684, 200.0822, 200.3672, 200.8265, 200.9358, 200.7341, 200.4963,
         200.7891, 200.8372, 200.4079],
        [201.3972, 200.0114, 200.2960, 200.7554, 200.8648, 200.6634, 200.4250,
         200.7177, 200.7663, 200.3367],
        [201.3944, 200.0084, 200.2932, 200.7533, 200.8619, 200.6600, 200.4221,
         200.7166, 200.7639, 200.3328],
        [201.3284, 199.9428, 200.2269, 200.6869, 200.7950, 200.5941, 200.3563,
         200.6504, 200.6980, 200.2664],
        [201.4291, 200.0425, 200.3277, 200.7873, 200.8969, 200.6949, 200.4569,
         200.7496, 200.7978, 200.3690],
        [201.4320, 200.0474, 200.3313, 200.7901, 200.8993, 200.6986, 200.4598,
         200.7525, 200.8015, 200.3711],
        [201.2556, 199.8682, 200.1533, 200.6140, 200.7230, 200.5211, 200.2835,
         200.5767, 200.6245, 200.1950],
        [201.4563, 200

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[200.5111, 200.8507, 200.7960, 200.1328, 200.7533, 200.8250, 200.5070,
         200.8721, 200.6732, 200.2948],
        [200.5640, 200.9055, 200.8504, 200.1867, 200.8068, 200.8771, 200.5633,
         200.9277, 200.7275, 200.3484],
        [200.4593, 200.7998, 200.7458, 200.0816, 200.7031, 200.7716, 200.4584,
         200.8232, 200.6198, 200.2428],
        [200.3074, 200.6486, 200.5942, 199.9296, 200.5512, 200.6218, 200.3060,
         200.6707, 200.4710, 200.0922],
        [200.3974, 200.7395, 200.6837, 200.0209, 200.6399, 200.7122, 200.3966,
         200.7605, 200.5614, 200.1828],
        [200.3308, 200.6727, 200.6175, 199.9544, 200.5746, 200.6451, 200.3305,
         200.6948, 200.4921, 200.1157],
        [200.3039, 200.6444, 200.5908, 199.9253, 200.5481, 200.6176, 200.3020,
         200.6673, 200.4667, 200.0882],
        [200.6546, 200

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[200.2488, 199.8502, 200.1736, 200.4099, 200.1912, 200.2740, 200.3846,
         200.2206, 199.5119, 199.6092],
        [200.5011, 200.1021, 200.4273, 200.6623, 200.4439, 200.5267, 200.6376,
         200.4735, 199.7645, 199.8624],
        [200.5148, 200.1154, 200.4398, 200.6746, 200.4574, 200.5412, 200.6517,
         200.4866, 199.7788, 199.8759],
        [200.3936, 199.9938, 200.3174, 200.5542, 200.3362, 200.4188, 200.5287,
         200.3655, 199.6563, 199.7548],
        [200.2800, 199.8808, 200.2042, 200.4402, 200.2225, 200.3060, 200.4163,
         200.2517, 199.5438, 199.6409],
        [200.3310, 199.9321, 200.2540, 200.4912, 200.2736, 200.3567, 200.4664,
         200.3024, 199.5938, 199.6918],
        [200.5353, 200.1363, 200.4610, 200.6962, 200.4779, 200.5608, 200.6717,
         200.5074, 199.7984, 199.8962],
        [200.4542, 200

tensor([[201.0502, 199.5438, 200.1342, 200.5406, 200.2481, 200.4666, 200.6486,
         200.1272, 200.2289, 199.8090],
        [201.1553, 199.6462, 200.2360, 200.6422, 200.3529, 200.5701, 200.7502,
         200.2300, 200.3329, 199.9137],
        [200.8157, 199.3073, 199.8971, 200.3039, 200.0122, 200.2305, 200.4123,
         199.8912, 199.9924, 199.5725],
        [201.1680, 199.6586, 200.2487, 200.6541, 200.3650, 200.5832, 200.7625,
         200.2432, 200.3458, 199.9269],
        [201.0712, 199.5650, 200.1551, 200.5623, 200.2695, 200.4871, 200.6700,
         200.1475, 200.2497, 199.8293],
        [201.0967, 199.5875, 200.1773, 200.5838, 200.2936, 200.5110, 200.6917,
         200.1714, 200.2743, 199.8536],
        [201.1115, 199.6045, 200.1949, 200.6014, 200.3088, 200.5273, 200.7091,
         200.1883, 200.2906, 199.8689],
        [201.1843, 199.6787, 200.2690, 200.6757, 200.3830, 200.6009, 200.7835,
         200.2612, 200.3636, 199.9435],
        [201.2415, 199.7331, 200.3236, 200.7287,

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[200.8887, 200.6646, 199.9385, 200.4306, 201.3474, 200.8503, 200.8926,
         200.7012, 200.8444, 200.9438],
        [200.5526, 200.3256, 199.6004, 200.0931, 201.0107, 200.5130, 200.5569,
         200.3638, 200.5083, 200.6080],
        [200.6317, 200.4085, 199.6809, 200.1718, 201.0918, 200.5936, 200.6376,
         200.4461, 200.5878, 200.6853],
        [200.3191, 200.0941, 199.3679, 199.8603, 200.7787, 200.2808, 200.3257,
         200.1340, 200.2757, 200.3753],
        [200.5540, 200.3299, 199.6042, 200.0974, 201.0132, 200.5163, 200.5596,
         200.3688, 200.5105, 200.6116],
        [200.3891, 200.1623, 199.4373, 199.9303, 200.8477, 200.3501, 200.3945,
         200.2017, 200.3453, 200.4457],
        [200.7504, 200.5271, 199.8000, 200.2919, 201.2097, 200.7122, 200.7552,
         200.5644, 200.7063, 200.8051],
        [200.4913, 200.2641, 199.5399, 200.0336, 200.94

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[200.4786, 200.6557, 201.0783, 200.2200, 200.5477, 200.5551, 200.3162,
         200.3669, 199.5521, 200.5786],
        [200.4627, 200.6382, 201.0602, 200.2016, 200.5306, 200.5378, 200.2999,
         200.3481, 199.5346, 200.5600],
        [200.3793, 200.5566, 200.9793, 200.1204, 200.4491, 200.4553, 200.2166,
         200.2644, 199.4523, 200.4776],
        [200.5594, 200.7365, 201.1600, 200.3008, 200.6290, 200.6355, 200.3971,
         200.4455, 199.6327, 200.6584],
        [200.5182, 200.6916, 201.1149, 200.2561, 200.5833, 200.5937, 200.3550,
         200.4060, 199.5887, 200.6160],
        [200.6729, 200.8471, 201.2708, 200.4118, 200.7391, 200.7486, 200.5104,
         200.5613, 199.7446, 200.7717],
        [200.4216, 200.5983, 201.0214, 200.1627, 200.4904, 200.4980, 200.2589,
         200.3090, 199.4946, 200.5209],
        [200.3625, 200

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[200.6847, 200.1822, 200.4481, 199.7913, 200.4567, 200.3936, 200.6797,
         200.2661, 201.0108, 200.2929],
        [200.5672, 200.0629, 200.3285, 199.6739, 200.3364, 200.2747, 200.5633,
         200.1468, 200.8937, 200.1758],
        [200.4505, 199.9459, 200.2117, 199.5560, 200.2201, 200.1599, 200.4439,
         200.0306, 200.7748, 200.0570],
        [200.3075, 199.8007, 200.0670, 199.4127, 200.0748, 200.0146, 200.3017,
         199.8856, 200.6322, 199.9139],
        [200.4864, 199.9817, 200.2473, 199.5926, 200.2552, 200.1943, 200.4820,
         200.0659, 200.8124, 200.0948],
        [200.4788, 199.9747, 200.2412, 199.5843, 200.2496, 200.1877, 200.4722,
         200.0593, 200.8033, 200.0854],
        [200.8221, 200.3207, 200.5855, 199.9295, 200.5942, 200.5322, 200.8180,
         200.4042, 201.1488, 200.4317],
        [200.5538, 200

tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[200.2029, 200.6382, 200.6595, 200.8163, 201.4264, 200.6510, 200.1167,
         200.3987, 200.0272, 200.3558],
        [200.1037, 200.5399, 200.5618, 200.7185, 201.3287, 200.5522, 200.0200,
         200.3004, 199.9282, 200.2588],
        [199.7691, 200.2055, 200.2262, 200.3822, 200.9936, 200.2173, 199.6842,
         199.9644, 199.5915, 199.9224],
        [199.8286, 200.2648, 200.2865, 200.4417, 201.0532, 200.2773, 199.7445,
         200.0242, 199.6519, 199.9839],
        [200.0348, 200.4703, 200.4919, 200.6477, 201.2581, 200.4832, 199.9491,
         200.2306, 199.8589, 200.1888],
        [200.1563, 200.5919, 200.6130, 200.7700, 201.3797, 200.6043, 200.0702,
         200.3523, 199.9805, 200.3090],
        [200.1789, 200.6143, 200.6357, 200.7922, 201.4023, 200.6271, 200.0928,
         200.3748, 200.0033, 200.3322],
        [200.0081, 200.4438, 200.4644, 200.6210, 201.23

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[200.4475, 200.2199, 200.3708, 200.0040, 199.7924, 201.1015, 199.7604,
         200.0535, 200.1381, 200.5699],
        [200.3546, 200.1239, 200.2759, 199.9081, 199.6973, 201.0060, 199.6650,
         199.9564, 200.0432, 200.4758],
        [200.4477, 200.2177, 200.3697, 200.0024, 199.7910, 201.1000, 199.7596,
         200.0515, 200.1361, 200.5692],
        [200.4108, 200.1818, 200.3311, 199.9642, 199.7530, 201.0623, 199.7193,
         200.0117, 200.1007, 200.5318],
        [200.4932, 200.2643, 200.4138, 200.0473, 199.8354, 201.1449, 199.8022,
         200.0946, 200.1830, 200.6143],
        [200.4855, 200.2549, 200.4053, 200.0383, 199.8267, 201.1364, 199.7945,
         200.0863, 200.1729, 200.6058],
        [200.2530, 200.0234, 200.1724, 199.8047, 199.5946, 200.9040, 199.5612,
         199.8534, 199.9408, 200.3730],
        [200.3969, 200

tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[200.1114, 200.7108, 200.3528, 200.3283, 200.0278, 200.4029, 201.2393,
         200.4203, 200.4364, 201.3237],
        [199.7775, 200.3742, 200.0160, 199.9913, 199.6920, 200.0653, 200.9035,
         200.0809, 200.0986, 200.9852],
        [200.0343, 200.6322, 200.2748, 200.2503, 199.9504, 200.3248, 201.1611,
         200.3408, 200.3578, 201.2451],
        [200.0414, 200.6404, 200.2821, 200.2572, 199.9575, 200.3319, 201.1686,
         200.3474, 200.3649, 201.2526],
        [200.0540, 200.6524, 200.2927, 200.2671, 199.9684, 200.3411, 201.1809,
         200.3592, 200.3745, 201.2619],
        [200.1278, 200.7272, 200.3690, 200.3444, 200.0442, 200.4190, 201.2555,
         200.4360, 200.4522, 201.3397],
        [200.0510, 200.6497, 200.2915, 200.2665, 199.9671, 200.3411, 201.1780,
         200.3566, 200.3741, 201.2618],
        [199.9796, 200.5785, 200.2203, 200.1958, 199.89

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[200.2990, 200.0499, 200.2157, 200.1491, 200.3043, 200.3524, 199.7205,
         199.6491, 199.8522, 200.2438],
        [200.3421, 200.0949, 200.2609, 200.1922, 200.3476, 200.3963, 199.7636,
         199.6913, 199.8956, 200.2857],
        [200.0562, 199.8072, 199.9728, 199.9079, 200.0619, 200.1107, 199.4783,
         199.4053, 199.6099, 200.0021],
        [200.4754, 200.2264, 200.3918, 200.3250, 200.4812, 200.5297, 199.8968,
         199.8264, 200.0288, 200.4197],
        [200.2278, 199.9793, 200.1437, 200.0784, 200.2340, 200.2839, 199.6498,
         199.5779, 199.7816, 200.1730],
        [200.3221, 200.0743, 200.2390, 200.1720, 200.3282, 200.3781, 199.7441,
         199.6722, 199.8758, 200.2665],
        [200.3430, 200.0954, 200.2600, 200.1924, 200.3489, 200.3989, 199.7649,
         199.6931, 199.8965, 200.2870],
        [200.2639, 200

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[199.9008, 200.1340, 199.8509, 199.9258, 199.9850, 199.7669, 199.9013,
         200.2440, 199.9542, 200.0660],
        [200.0314, 200.2653, 199.9819, 200.0560, 200.1153, 199.8983, 200.0329,
         200.3752, 200.0845, 200.1968],
        [199.9807, 200.2131, 199.9302, 200.0051, 200.0634, 199.8465, 199.9796,
         200.3234, 200.0333, 200.1448],
        [200.0287, 200.2623, 199.9796, 200.0540, 200.1134, 199.8958, 200.0298,
         200.3726, 200.0826, 200.1939],
        [200.2089, 200.4417, 200.1589, 200.2331, 200.2912, 200.0759, 200.2083,
         200.5520, 200.2612, 200.3730],
        [199.9959, 200.2267, 199.9437, 200.0190, 200.0761, 199.8602, 199.9913,
         200.3376, 200.0469, 200.1570],
        [200.0413, 200.2723, 199.9895, 200.0647, 200.1219, 199.9060, 200.0371,
         200.3833, 200.0926, 200.2027],
        [200.1616, 200

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[199.5551, 200.4704, 199.0965, 199.9365, 199.7265, 199.7802, 199.8276,
         199.5486, 199.5131, 199.6480],
        [199.7920, 200.7047, 199.3352, 200.1712, 199.9634, 200.0154, 200.0623,
         199.7844, 199.7501, 199.8835],
        [199.8365, 200.7504, 199.3806, 200.2171, 200.0078, 200.0614, 200.1075,
         199.8289, 199.7959, 199.9302],
        [199.9299, 200.8450, 199.4731, 200.3114, 200.1014, 200.1542, 200.1996,
         199.9227, 199.8893, 200.0241],
        [199.7808, 200.6944, 199.3235, 200.1605, 199.9524, 200.0052, 200.0523,
         199.7741, 199.7393, 199.8718],
        [199.7532, 200.6661, 199.2952, 200.1323, 199.9247, 199.9759, 200.0231,
         199.7459, 199.7105, 199.8435],
        [199.7283, 200.6407, 199.2710, 200.1079, 199.8994, 199.9500, 199.9966,
         199.7197, 199.6854, 199.8212],
        [199.9267, 200

tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[199.9673, 200.0348, 200.2147, 200.0612, 199.9819, 200.0155, 199.7682,
         200.3253, 200.2944, 200.6741],
        [199.7152, 199.7838, 199.9637, 199.8089, 199.7313, 199.7638, 199.5170,
         200.0740, 200.0421, 200.4243],
        [199.9222, 199.9910, 200.1715, 200.0168, 199.9375, 199.9716, 199.7247,
         200.2816, 200.2487, 200.6303],
        [199.7961, 199.8661, 200.0458, 199.8917, 199.8110, 199.8459, 199.5997,
         200.1558, 200.1235, 200.5047],
        [199.7903, 199.8608, 200.0402, 199.8853, 199.8071, 199.8408, 199.5939,
         200.1505, 200.1170, 200.5013],
        [199.7330, 199.8030, 199.9825, 199.8279, 199.7487, 199.7827, 199.5363,
         200.0926, 200.0601, 200.4427],
        [199.8836, 199.9535, 200.1319, 199.9787, 199.8995, 199.9338, 199.6865,
         200.2429, 200.2112, 200.5937],
        [200.0510, 200.1201, 200.3007, 200.1451, 200.06

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[199.9291, 199.7729, 199.8317, 199.9931, 199.9814, 200.0837, 199.6135,
         199.1994, 199.7342, 200.3456],
        [200.0766, 199.9219, 199.9813, 200.1418, 200.1302, 200.2329, 199.7634,
         199.3481, 199.8810, 200.4933],
        [200.2553, 200.1008, 200.1589, 200.3200, 200.3094, 200.4118, 199.9434,
         199.5286, 200.0603, 200.6716],
        [200.0435, 199.8874, 199.9468, 200.1061, 200.0971, 200.1989, 199.7295,
         199.3156, 199.8480, 200.4595],
        [200.0794, 199.9238, 199.9831, 200.1427, 200.1331, 200.2352, 199.7659,
         199.3517, 199.8838, 200.4954],
        [199.9694, 199.8168, 199.8754, 200.0379, 200.0245, 200.1279, 199.6575,
         199.2406, 199.7751, 200.3860],
        [200.0405, 199.8868, 199.9439, 200.1076, 200.0946, 200.1975, 199.7278,
         199.3122, 199.8469, 200.4570],
        [200.0096, 199

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[200.8692, 200.2324, 200.1949, 199.7738, 200.0468, 199.6945, 199.8929,
         199.9080, 200.0585, 199.9290],
        [200.7293, 200.0928, 200.0563, 199.6337, 199.9082, 199.5544, 199.7517,
         199.7685, 199.9160, 199.7895],
        [200.8570, 200.2206, 200.1839, 199.7618, 200.0352, 199.6819, 199.8802,
         199.8969, 200.0446, 199.9174],
        [200.7951, 200.1590, 200.1223, 199.6997, 199.9730, 199.6200, 199.8189,
         199.8357, 199.9830, 199.8558],
        [200.7144, 200.0783, 200.0414, 199.6186, 199.8920, 199.5395, 199.7386,
         199.7547, 199.9031, 199.7749],
        [200.7949, 200.1580, 200.1216, 199.6989, 199.9732, 199.6201, 199.8172,
         199.8331, 199.9821, 199.8548],
        [200.8765, 200.2398, 200.2021, 199.7812, 200.0538, 199.7018, 199.9008,
         199.9155, 200.0666, 199.9363],
        [200.7022, 200

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[199.5845, 200.0596, 200.0061, 199.7732, 199.4012, 199.9861, 199.7939,
         199.7753, 200.1404, 200.1931],
        [199.6442, 200.1229, 200.0673, 199.8361, 199.4626, 200.0491, 199.8565,
         199.8381, 200.2037, 200.2528],
        [199.6211, 200.0972, 200.0441, 199.8103, 199.4381, 200.0243, 199.8313,
         199.8129, 200.1781, 200.2287],
        [199.4882, 199.9681, 199.9129, 199.6799, 199.3066, 199.8949, 199.7012,
         199.6830, 200.0500, 200.0971],
        [199.4930, 199.9715, 199.9173, 199.6838, 199.3107, 199.8985, 199.7049,
         199.6864, 200.0532, 200.1015],
        [199.5566, 200.0360, 199.9799, 199.7478, 199.3752, 199.9622, 199.7692,
         199.7515, 200.1176, 200.1658],
        [199.5275, 200.0041, 199.9501, 199.7177, 199.3446, 199.9306, 199.7380,
         199.7191, 200.0850, 200.1364],
        [199.6957, 200

tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[200.2184, 199.9491, 199.9374, 199.8397, 200.0905, 199.9992, 200.1535,
         199.9747, 200.1485, 199.9159],
        [200.4242, 200.1563, 200.1419, 200.0461, 200.2974, 200.2063, 200.3601,
         200.1806, 200.3559, 200.1238],
        [199.9780, 199.7115, 199.6959, 199.6007, 199.8494, 199.7608, 199.9160,
         199.7341, 199.9089, 199.6780],
        [200.0133, 199.7456, 199.7306, 199.6352, 199.8857, 199.7960, 199.9510,
         199.7691, 199.9452, 199.7127],
        [200.0966, 199.8279, 199.8140, 199.7182, 199.9696, 199.8789, 200.0337,
         199.8525, 200.0289, 199.7955],
        [200.1571, 199.8893, 199.8745, 199.7801, 200.0294, 199.9404, 200.0959,
         199.9133, 200.0896, 199.8570],
        [200.3006, 200.0348, 200.0177, 199.9233, 200.1731, 200.0837, 200.2378,
         200.0569, 200.2322, 200.0017],
        [200.2994, 200.0319, 200.0169, 199.9212, 200.17

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[200.4557, 199.6130, 200.1092, 199.7773, 199.4254, 199.9274, 199.2373,
         199.8587, 200.7373, 199.5275],
        [200.6002, 199.7566, 200.2559, 199.9223, 199.5687, 200.0701, 199.3812,
         200.0049, 200.8803, 199.6727],
        [200.5699, 199.7283, 200.2266, 199.8933, 199.5392, 200.0405, 199.3524,
         199.9756, 200.8513, 199.6441],
        [200.7955, 199.9537, 200.4509, 200.1185, 199.7656, 200.2652, 199.5780,
         200.1990, 201.0767, 199.8695],
        [200.5549, 199.7128, 200.2077, 199.8767, 199.5253, 200.0264, 199.3370,
         199.9570, 200.8368, 199.6269],
        [200.5908, 199.7532, 200.2483, 199.9180, 199.5616, 200.0611, 199.3764,
         199.9971, 200.8741, 199.6685],
        [200.3955, 199.5536, 200.0516, 199.7184, 199.3644, 199.8670, 199.1777,
         199.8015, 200.6770, 199.4689],
        [200.5543, 199

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[199.8537, 199.9851, 199.9514, 199.9306, 199.0663, 199.6995, 199.8483,
         199.5230, 199.6113, 200.1122],
        [200.0049, 200.1345, 200.1023, 200.0821, 199.2188, 199.8507, 199.9997,
         199.6742, 199.7622, 200.2618],
        [199.8164, 199.9482, 199.9142, 199.8934, 199.0282, 199.6625, 199.8109,
         199.4852, 199.5739, 200.0747],
        [199.8207, 199.9508, 199.9182, 199.8983, 199.0361, 199.6670, 199.8162,
         199.4923, 199.5784, 200.0775],
        [199.8544, 199.9845, 199.9523, 199.9314, 199.0693, 199.7003, 199.8502,
         199.5244, 199.6116, 200.1109],
        [200.1005, 200.2296, 200.1980, 200.1774, 199.3160, 199.9458, 200.0957,
         199.7704, 199.8579, 200.3578],
        [199.9238, 200.0531, 200.0214, 200.0014, 199.1391, 199.7701, 199.9194,
         199.5941, 199.6811, 200.1796],
        [199.5736, 199

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[200.3081, 200.4364, 199.6181, 199.9342, 199.2735, 199.6523, 199.0695,
         199.2415, 199.9382, 199.6214],
        [200.2262, 200.3539, 199.5374, 199.8520, 199.1925, 199.5703, 198.9887,
         199.1605, 199.8563, 199.5385],
        [200.3511, 200.4781, 199.6562, 199.9723, 199.3144, 199.6935, 199.1101,
         199.2806, 199.9805, 199.6616],
        [200.2386, 200.3656, 199.5464, 199.8627, 199.2038, 199.5820, 199.0008,
         199.1706, 199.8688, 199.5499],
        [200.3877, 200.5145, 199.6931, 200.0092, 199.3512, 199.7305, 199.1476,
         199.3177, 200.0169, 199.6984],
        [200.1534, 200.2815, 199.4634, 199.7780, 199.1187, 199.4964, 198.9128,
         199.0863, 199.7831, 199.4648],
        [200.2724, 200.3994, 199.5815, 199.8957, 199.2377, 199.6157, 199.0339,
         199.2050, 199.9019, 199.5832],
        [200.1469, 200

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[199.4881, 200.5325, 200.2639, 199.6083, 199.7542, 200.0055, 199.2027,
         199.4737, 200.4131, 199.9407],
        [199.5763, 200.6212, 200.3528, 199.6971, 199.8424, 200.0938, 199.2916,
         199.5624, 200.5021, 200.0294],
        [199.5081, 200.5522, 200.2834, 199.6278, 199.7731, 200.0239, 199.2221,
         199.4942, 200.4314, 199.9612],
        [199.3531, 200.3965, 200.1287, 199.4718, 199.6161, 199.8682, 199.0654,
         199.3382, 200.2747, 199.8066],
        [199.3879, 200.4339, 200.1659, 199.5091, 199.6553, 199.9068, 199.1036,
         199.3749, 200.3152, 199.8430],
        [199.5367, 200.5820, 200.3134, 199.6577, 199.8039, 200.0549, 199.2524,
         199.5231, 200.4633, 199.9901],
        [199.4150, 200.4581, 200.1905, 199.5337, 199.6774, 199.9299, 199.1271,
         199.3997, 200.3365, 199.8680],
        [199.5597, 200

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[200.0579, 199.7143, 199.7084, 199.6973, 200.0382, 200.0004, 199.5807,
         199.8914, 200.5315, 199.9513],
        [199.9368, 199.5903, 199.5876, 199.5775, 199.9182, 199.8805, 199.4586,
         199.7690, 200.4115, 199.8299],
        [199.7526, 199.4077, 199.4040, 199.3940, 199.7344, 199.6958, 199.2736,
         199.5854, 200.2266, 199.6459],
        [199.9725, 199.6275, 199.6234, 199.6139, 199.9532, 199.9147, 199.4944,
         199.8044, 200.4451, 199.8655],
        [199.6478, 199.3027, 199.2989, 199.2892, 199.6293, 199.5913, 199.1683,
         199.4812, 200.1221, 199.5409],
        [199.9529, 199.6067, 199.6039, 199.5933, 199.9345, 199.8965, 199.4749,
         199.7852, 200.4275, 199.8461],
        [199.8912, 199.5445, 199.5419, 199.5319, 199.8725, 199.8349, 199.4128,
         199.7234, 200.3657, 199.7841],
        [199.7504, 199

tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[199.7833, 198.7702, 200.4060, 199.6248, 199.8171, 199.7719, 199.5883,
         199.6933, 199.3549, 199.0711],
        [199.8794, 198.8662, 200.5019, 199.7210, 199.9120, 199.8681, 199.6843,
         199.7883, 199.4509, 199.1671],
        [199.9265, 198.9155, 200.5500, 199.7718, 199.9586, 199.9157, 199.7305,
         199.8353, 199.5007, 199.2153],
        [199.7393, 198.7285, 200.3627, 199.5846, 199.7722, 199.7288, 199.5423,
         199.6487, 199.3134, 199.0274],
        [199.9713, 198.9584, 200.5942, 199.8140, 200.0034, 199.9596, 199.7769,
         199.8797, 199.5437, 199.2605],
        [199.7874, 198.7745, 200.4101, 199.6294, 199.8208, 199.7761, 199.5920,
         199.6971, 199.3593, 199.0752],
        [199.7722, 198.7607, 200.3953, 199.6160, 199.8054, 199.7616, 199.5757,
         199.6821, 199.3453, 199.0596],
        [199.8542, 198.8423, 200.4771, 199.6988, 199.88

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.9600, 199.4161, 199.2375, 199.6480, 199.5243, 199.4216, 200.1558,
         200.2951, 199.3372, 199.5158],
        [199.0330, 199.4871, 199.3085, 199.7171, 199.5947, 199.4934, 200.2249,
         200.3655, 199.4057, 199.5869],
        [199.1767, 199.6324, 199.4525, 199.8637, 199.7409, 199.6366, 200.3689,
         200.5104, 199.5524, 199.7330],
        [199.1585, 199.6124, 199.4342, 199.8419, 199.7207, 199.6184, 200.3524,
         200.4919, 199.5338, 199.7113],
        [199.2352, 199.6903, 199.5111, 199.9210, 199.7991, 199.6948, 200.4290,
         200.5693, 199.6119, 199.7899],
        [199.3335, 199.7884, 199.6091, 200.0181, 199.8968, 199.7927, 200.5263,
         200.6671, 199.7093, 199.8878],
        [199.0433, 199.4980, 199.3188, 199.7289, 199.6061, 199.5035, 200.2352,
         200.3762, 199.4172, 199.5984],
        [199.1929, 199

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[199.3685, 199.0759, 199.8572, 199.0428, 199.7065, 199.3193, 199.5506,
         199.3988, 199.7335, 199.7629],
        [199.4647, 199.1718, 199.9534, 199.1381, 199.8027, 199.4150, 199.6457,
         199.4964, 199.8289, 199.8596],
        [199.0305, 198.7392, 199.5210, 198.7053, 199.3694, 198.9826, 199.2118,
         199.0611, 199.3973, 199.4258],
        [199.2238, 198.9317, 199.7135, 198.8978, 199.5626, 199.1754, 199.4050,
         199.2546, 199.5897, 199.6189],
        [199.4440, 199.1511, 199.9329, 199.1174, 199.7822, 199.3946, 199.6251,
         199.4754, 199.8086, 199.8388],
        [199.4398, 199.1470, 199.9294, 199.1145, 199.7764, 199.3903, 199.6201,
         199.4707, 199.8052, 199.8346],
        [199.2666, 198.9741, 199.7567, 198.9421, 199.6035, 199.2184, 199.4473,
         199.2959, 199.6338, 199.6610],
        [199.4785, 199

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[199.8764, 199.8327, 199.8533, 198.9868, 200.5543, 199.6745, 200.4660,
         200.0031, 198.8331, 199.5083],
        [199.6939, 199.6492, 199.6692, 198.8036, 200.3707, 199.4895, 200.2819,
         199.8199, 198.6492, 199.3252],
        [199.9243, 199.8810, 199.9019, 199.0352, 200.6019, 199.7230, 200.5141,
         200.0511, 198.8810, 199.5575],
        [199.5002, 199.4556, 199.4752, 198.6093, 200.1767, 199.2941, 200.0880,
         199.6253, 198.4565, 199.1312],
        [199.9128, 199.8689, 199.8897, 199.0232, 200.5911, 199.7117, 200.5025,
         200.0394, 198.8694, 199.5440],
        [199.9986, 199.9559, 199.9765, 199.1100, 200.6767, 199.7959, 200.5890,
         200.1266, 198.9566, 199.6323],
        [199.8316, 199.7888, 199.8092, 198.9423, 200.5096, 199.6284, 200.4218,
         199.9585, 198.7902, 199.4647],
        [199.7759, 199

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[199.3971, 199.1694, 199.2015, 199.5651, 199.2415, 199.2106, 199.6519,
         198.8778, 200.1772, 199.4041],
        [199.5256, 199.2971, 199.3276, 199.6916, 199.3687, 199.3344, 199.7770,
         199.0045, 200.3040, 199.5301],
        [199.6338, 199.4071, 199.4380, 199.8019, 199.4782, 199.4471, 199.8882,
         199.1141, 200.4141, 199.6401],
        [199.7058, 199.4780, 199.5086, 199.8723, 199.5488, 199.5143, 199.9577,
         199.1848, 200.4841, 199.7104],
        [199.4823, 199.2546, 199.2860, 199.6498, 199.3261, 199.2944, 199.7363,
         198.9624, 200.2618, 199.4883],
        [199.7847, 199.5586, 199.5894, 199.9530, 199.6297, 199.5975, 200.0389,
         199.2649, 200.5652, 199.7915],
        [199.5725, 199.3457, 199.3767, 199.7406, 199.4168, 199.3860, 199.8271,
         199.0529, 200.3528, 199.5789],
        [199.6201, 199

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[199.1186, 199.4792, 199.4154, 199.5464, 200.1492, 199.6553, 199.4164,
         199.7696, 200.5315, 199.5085],
        [198.9147, 199.2757, 199.2097, 199.3420, 199.9439, 199.4510, 199.2132,
         199.5659, 200.3258, 199.3060],
        [199.0624, 199.4247, 199.3571, 199.4897, 200.0908, 199.5992, 199.3597,
         199.7138, 200.4738, 199.4529],
        [199.0936, 199.4542, 199.3905, 199.5216, 200.1243, 199.6306, 199.3914,
         199.7447, 200.5067, 199.4837],
        [199.0423, 199.4043, 199.3369, 199.4691, 200.0705, 199.5785, 199.3403,
         199.6936, 200.4533, 199.4328],
        [199.0527, 199.4129, 199.3490, 199.4810, 200.0834, 199.5898, 199.3503,
         199.7036, 200.4653, 199.4436],
        [199.3327, 199.6938, 199.6280, 199.7600, 200.3615, 199.8692, 199.6293,
         199.9837, 200.7446, 199.7226],
        [198.7679, 199

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.6290, 199.6116, 199.4685, 198.9614, 199.2990, 200.2593, 199.3049,
         198.9337, 199.4871, 199.6661],
        [198.6416, 199.6237, 199.4823, 198.9743, 199.3121, 200.2723, 199.3164,
         198.9464, 199.4986, 199.6779],
        [198.8447, 199.8277, 199.6827, 199.1758, 199.5143, 200.4746, 199.5192,
         199.1501, 199.7019, 199.8803],
        [198.6861, 199.6695, 199.5235, 199.0173, 199.3555, 200.3157, 199.3614,
         198.9908, 199.5440, 199.7225],
        [198.4666, 199.4508, 199.3029, 198.7971, 199.1353, 200.0955, 199.1395,
         198.7700, 199.3223, 199.5008],
        [198.6785, 199.6621, 199.5153, 199.0089, 199.3473, 200.3075, 199.3516,
         198.9828, 199.5343, 199.7126],
        [198.5101, 199.4926, 199.3502, 198.8426, 199.1805, 200.1406, 199.1844,
         198.8142, 199.3668, 199.5464],
        [198.5142, 199

tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.5913, 199.4488, 200.0821, 199.4061, 199.6277, 199.5026, 199.2153,
         199.6006, 199.2909, 199.5733],
        [198.4622, 199.3170, 199.9509, 199.2750, 199.4957, 199.3726, 199.0851,
         199.4702, 199.1604, 199.4449],
        [198.4899, 199.3443, 199.9780, 199.3023, 199.5228, 199.4000, 199.1126,
         199.4976, 199.1878, 199.4727],
        [198.4860, 199.3427, 199.9753, 199.2978, 199.5212, 199.3957, 199.1083,
         199.4937, 199.1843, 199.4693],
        [198.8221, 199.6781, 200.3113, 199.6361, 199.8565, 199.7330, 199.4460,
         199.8302, 199.5210, 199.8042],
        [198.5332, 199.3899, 200.0242, 199.3472, 199.5690, 199.4442, 199.1567,
         199.5420, 199.2326, 199.5156],
        [198.5515, 199.4082, 200.0428, 199.3664, 199.5875, 199.4630, 199.1755,
         199.5608, 199.2511, 199.5336],
        [198.4955, 199.3521, 199.9871, 199.3103, 199.53

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.9532, 199.2051, 199.6630, 199.6172, 199.3212, 199.5402, 198.7146,
         199.2968, 199.3791, 199.7332],
        [198.5598, 198.8105, 199.2703, 199.2226, 198.9264, 199.1459, 198.3197,
         198.9005, 198.9874, 199.3393],
        [198.9892, 199.2408, 199.6985, 199.6519, 199.3568, 199.5750, 198.7506,
         199.3307, 199.4164, 199.7693],
        [199.0652, 199.3168, 199.7742, 199.7277, 199.4328, 199.6506, 198.8268,
         199.4067, 199.4925, 199.8453],
        [198.9065, 199.1586, 199.6161, 199.5693, 199.2740, 199.4926, 198.6681,
         199.2477, 199.3343, 199.6870],
        [198.8728, 199.1247, 199.5808, 199.5345, 199.2393, 199.4567, 198.6328,
         199.2140, 199.2997, 199.6530],
        [198.8706, 199.1210, 199.5812, 199.5338, 199.2381, 199.4574, 198.6318,
         199.2119, 199.2980, 199.6500],
        [198.7785, 199

tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[199.5588, 199.3990, 198.8508, 199.6381, 199.3954, 199.5347, 199.4490,
         199.8884, 199.7608, 199.7034],
        [199.4327, 199.2716, 198.7215, 199.5106, 199.2705, 199.4082, 199.3218,
         199.7601, 199.6341, 199.5746],
        [199.2871, 199.1246, 198.5771, 199.3648, 199.1238, 199.2616, 199.1761,
         199.6141, 199.4886, 199.4306],
        [199.3205, 199.1576, 198.6110, 199.3969, 199.1580, 199.2945, 199.2092,
         199.6461, 199.5210, 199.4640],
        [199.3627, 199.2031, 198.6535, 199.4442, 199.1985, 199.3386, 199.2537,
         199.6937, 199.5670, 199.5086],
        [199.2366, 199.0765, 198.5274, 199.3177, 199.0726, 199.2119, 199.1276,
         199.5669, 199.4409, 199.3833],
        [199.1487, 198.9872, 198.4377, 199.2284, 198.9851, 199.1235, 199.0387,
         199.4773, 199.3522, 199.2935],
        [199.3340, 199.1746, 198.6237, 199.4152, 199.17

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[199.1659, 199.1488, 199.2457, 199.2504, 199.6609, 198.9782, 199.5344,
         198.8244, 199.7572, 198.7979],
        [199.2182, 199.1987, 199.2978, 199.3021, 199.7097, 199.0306, 199.5833,
         198.8749, 199.8102, 198.8493],
        [199.0377, 199.0135, 199.1138, 199.1204, 199.5254, 198.8456, 199.4010,
         198.6921, 199.6254, 198.6645],
        [199.3367, 199.3166, 199.4147, 199.4192, 199.8280, 199.1475, 199.7009,
         198.9927, 199.9256, 198.9668],
        [199.1379, 199.1188, 199.2171, 199.2225, 199.6318, 198.9489, 199.5057,
         198.7953, 199.7284, 198.7688],
        [199.1223, 199.1028, 199.2019, 199.2063, 199.6130, 198.9350, 199.4872,
         198.7792, 199.7147, 198.7531],
        [199.2419, 199.2201, 199.3188, 199.3243, 199.7322, 199.0510, 199.6060,
         198.8971, 199.8295, 198.8704],
        [199.2028, 199

       grad_fn=<CdistBackward0>)
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[199.3136, 198.9240, 199.1936, 199.1977, 199.6008, 199.1741, 199.0136,
         198.9988, 199.3242, 199.3199],
        [199.3849, 198.9971, 199.2642, 199.2704, 199.6747, 199.2471, 199.0856,
         199.0714, 199.3975, 199.3923],
        [199.2945, 198.9082, 199.1740, 199.1802, 199.5845, 199.1581, 198.9961,
         198.9810, 199.3084, 199.3021],
        [199.5161, 199.1280, 199.3957, 199.4005, 199.8036, 199.3777, 199.2169,
         199.2011, 199.5272, 199.5217],
        [199.1075, 198.7216, 198.9876, 198.9932, 199.3974, 198.9717, 198.8094,
         198.7941, 199.1221, 199.1156],
        [199.6409, 199.2529, 199.5205, 199.5255, 199.9289, 199.5021, 199.3424,
         199.3266, 199.6523, 199.6465],
        [199.5110, 199.1274, 199.3905, 199.3969, 199.8011, 199.3761, 199.2151,
         199.1976, 199.5264, 199.5177],
        [199.6469, 199

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[199.3234, 199.9534, 199.0787, 198.8656, 198.9955, 199.0835, 198.9810,
         199.1818, 199.1467, 199.8344],
        [199.4630, 200.0932, 199.2198, 199.0057, 199.1352, 199.2227, 199.1211,
         199.3215, 199.2892, 199.9757],
        [199.3012, 199.9318, 199.0581, 198.8440, 198.9741, 199.0608, 198.9582,
         199.1582, 199.1264, 199.8114],
        [199.3895, 200.0190, 199.1444, 198.9314, 199.0609, 199.1495, 199.0478,
         199.2491, 199.2135, 199.9020],
        [199.2514, 199.8813, 199.0077, 198.7926, 198.9234, 199.0112, 198.9086,
         199.1096, 199.0773, 199.7636],
        [199.3328, 199.9626, 199.0889, 198.8753, 199.0045, 199.0916, 198.9909,
         199.1911, 199.1585, 199.8447],
        [199.4851, 200.1154, 199.2421, 199.0279, 199.1574, 199.2448, 199.1433,
         199.3437, 199.3118, 199.9981],
        [199.2848, 199

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[199.0943, 199.2842, 199.0718, 199.4137, 199.2558, 200.2645, 198.6557,
         199.3124, 199.3843, 199.3018],
        [199.0345, 199.2243, 199.0095, 199.3532, 199.1953, 200.2053, 198.5943,
         199.2539, 199.3237, 199.2409],
        [199.1047, 199.2948, 199.0788, 199.4241, 199.2661, 200.2757, 198.6651,
         199.3237, 199.3924, 199.3116],
        [199.2248, 199.4152, 199.2025, 199.5448, 199.3864, 200.3950, 198.7881,
         199.4423, 199.5134, 199.4331],
        [199.1749, 199.3649, 199.1510, 199.4947, 199.3369, 200.3456, 198.7367,
         199.3929, 199.4635, 199.3827],
        [198.9723, 199.1623, 198.9500, 199.2924, 199.1348, 200.1424, 198.5336,
         199.1898, 199.2621, 199.1809],
        [199.0614, 199.2510, 199.0386, 199.3814, 199.2240, 200.2316, 198.6225,
         199.2788, 199.3513, 199.2696],
        [199.0595, 199

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.9069, 199.2224, 199.0935, 199.4963, 199.2423, 198.8572, 199.4005,
         198.8907, 198.9116, 199.0390],
        [198.9355, 199.2500, 199.1236, 199.5278, 199.2693, 198.8847, 199.4270,
         198.9219, 198.9387, 199.0642],
        [199.1167, 199.4303, 199.3036, 199.7062, 199.4507, 199.0664, 199.6090,
         199.1011, 199.1203, 199.2471],
        [198.7090, 199.0252, 198.8963, 199.3001, 199.0445, 198.6590, 199.2016,
         198.6941, 198.7130, 198.8404],
        [198.9303, 199.2450, 199.1165, 199.5192, 199.2652, 198.8801, 199.4241,
         198.9139, 198.9348, 199.0617],
        [198.9865, 199.3008, 199.1727, 199.5752, 199.3211, 198.9362, 199.4798,
         198.9699, 198.9906, 199.1177],
        [199.0169, 199.3314, 199.2035, 199.6068, 199.3512, 198.9665, 199.5106,
         199.0010, 199.0215, 199.1472],
        [199.0815, 199

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.5632, 199.2969, 199.9299, 199.0623, 199.2193, 198.4337, 199.3110,
         198.9728, 199.1158, 198.3966],
        [198.7566, 199.4879, 200.1211, 199.2534, 199.4108, 198.6245, 199.5025,
         199.1642, 199.3094, 198.5884],
        [198.6341, 199.3683, 200.0011, 199.1332, 199.2910, 198.5047, 199.3819,
         199.0441, 199.1869, 198.4682],
        [198.7756, 199.5077, 200.1408, 199.2722, 199.4303, 198.6447, 199.5220,
         199.1837, 199.3283, 198.6077],
        [198.5623, 199.2949, 199.9281, 199.0607, 199.2169, 198.4322, 199.3096,
         198.9709, 199.1148, 198.3945],
        [198.5292, 199.2625, 199.8950, 199.0270, 199.1848, 198.3983, 199.2766,
         198.9377, 199.0818, 198.3614],
        [198.4910, 199.2256, 199.8572, 198.9896, 199.1491, 198.3576, 199.2389,
         198.9003, 199.0448, 198.3243],
        [198.4210, 199

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.9008, 198.6590, 199.0808, 199.3237, 198.9858, 199.1645, 198.9824,
         199.4914, 199.0210, 198.6262],
        [198.9254, 198.6836, 199.1046, 199.3489, 199.0100, 199.1888, 199.0080,
         199.5184, 199.0467, 198.6497],
        [198.7572, 198.5148, 198.9359, 199.1786, 198.8393, 199.0216, 198.8360,
         199.3465, 198.8757, 198.4829],
        [198.7781, 198.5360, 198.9563, 199.2008, 198.8610, 199.0419, 198.8588,
         199.3703, 198.8984, 198.5024],
        [198.9247, 198.6827, 199.1048, 199.3476, 199.0099, 199.1879, 199.0064,
         199.5157, 199.0450, 198.6497],
        [199.0843, 198.8420, 199.2636, 199.5071, 199.1682, 199.3480, 199.1662,
         199.6765, 199.2050, 198.8092],
        [198.9398, 198.6975, 199.1197, 199.3624, 199.0245, 199.2028, 199.0211,
         199.5307, 199.0599, 198.6647],
        [198.8604, 198

       grad_fn=<CdistBackward0>)
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.9710, 198.8326, 198.8579, 199.0654, 199.1562, 198.8690, 199.0092,
         199.3249, 200.0276, 199.9459],
        [199.0976, 198.9585, 198.9841, 199.1911, 199.2823, 198.9955, 199.1351,
         199.4516, 200.1539, 200.0714],
        [199.0885, 198.9515, 198.9737, 199.1814, 199.2714, 198.9852, 199.1263,
         199.4402, 200.1427, 200.0627],
        [199.0902, 198.9541, 198.9752, 199.1839, 199.2733, 198.9868, 199.1284,
         199.4407, 200.1438, 200.0649],
        [198.9873, 198.8489, 198.8728, 199.0801, 199.1707, 198.8838, 199.0252,
         199.3400, 200.0421, 199.9604],
        [198.9027, 198.7661, 198.7878, 198.9957, 199.0857, 198.7987, 198.9415,
         199.2540, 199.9567, 199.8767],
        [198.9099, 198.7731, 198.7976, 199.0057, 199.0963, 198.8086, 198.9490,
         199.2638, 199.9671, 199.8869],
        [199.0692, 198

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.8452, 199.0253, 199.7566, 199.2899, 199.2783, 199.2690, 198.8944,
         199.2470, 199.8968, 199.4355],
        [198.7016, 198.8804, 199.6130, 199.1448, 199.1342, 199.1252, 198.7499,
         199.1002, 199.7523, 199.2907],
        [198.7118, 198.8912, 199.6231, 199.1557, 199.1447, 199.1354, 198.7604,
         199.1118, 199.7629, 199.3012],
        [198.9536, 199.1323, 199.8664, 199.3985, 199.3876, 199.3775, 199.0036,
         199.3555, 200.0053, 199.5450],
        [198.7186, 198.8966, 199.6322, 199.1622, 199.1520, 199.1434, 198.7671,
         199.1186, 199.7697, 199.3106],
        [198.5454, 198.7245, 199.4582, 198.9896, 198.9787, 198.9698, 198.5936,
         198.9464, 199.5968, 199.1364],
        [198.7744, 198.9525, 199.6878, 199.2183, 199.2079, 199.1989, 198.8231,
         199.1749, 199.8256, 199.3663],
        [198.7009, 198

tensor([[199.1918, 198.9264, 199.5927, 199.0443, 199.2073, 198.9054, 198.9307,
         199.0688, 199.2388, 199.0918],
        [199.0731, 198.8081, 199.4737, 198.9258, 199.0895, 198.7891, 198.8126,
         198.9515, 199.1210, 198.9750],
        [199.1344, 198.8691, 199.5346, 198.9879, 199.1490, 198.8515, 198.8743,
         199.0129, 199.1822, 199.0357],
        [199.0495, 198.7842, 199.4504, 198.9014, 199.0654, 198.7631, 198.7884,
         198.9264, 199.0967, 198.9500],
        [199.0433, 198.7780, 199.4443, 198.8954, 199.0592, 198.7573, 198.7822,
         198.9206, 199.0906, 198.9437],
        [199.1757, 198.9104, 199.5769, 199.0280, 199.1915, 198.8889, 198.9146,
         199.0525, 199.2227, 199.0757],
        [199.1646, 198.8994, 199.5651, 199.0189, 199.1796, 198.8820, 198.9042,
         199.0440, 199.2127, 199.0655],
        [199.3700, 199.1053, 199.7710, 199.2239, 199.3864, 199.0856, 199.1096,
         199.2484, 199.4177, 199.2713],
        [199.0605, 198.7947, 199.4614, 198.9140,

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.8661, 198.3463, 198.8746, 198.7707, 199.0536, 198.8156, 198.9892,
         198.9200, 198.6872, 199.0467],
        [199.0592, 198.5412, 199.0692, 198.9665, 199.2487, 199.0119, 199.1850,
         199.1149, 198.8833, 199.2428],
        [199.0932, 198.5750, 199.1030, 198.9995, 199.2820, 199.0461, 199.2183,
         199.1491, 198.9163, 199.2759],
        [198.9222, 198.4034, 198.9310, 198.8263, 199.1095, 198.8734, 199.0456,
         198.9774, 198.7437, 199.1027],
        [199.0948, 198.5772, 199.1045, 199.0013, 199.2842, 199.0488, 199.2191,
         199.1511, 198.9171, 199.2770],
        [198.8641, 198.3443, 198.8720, 198.7686, 199.0517, 198.8142, 198.9851,
         198.9179, 198.6830, 199.0430],
        [198.8909, 198.3727, 198.9004, 198.7972, 199.0805, 198.8429, 199.0160,
         198.9463, 198.7144, 199.0741],
        [199.0156, 198

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.8860, 198.9676, 198.5142, 199.2034, 198.4694, 198.6939, 199.1967,
         199.4695, 198.8775, 198.9364],
        [198.9857, 199.0659, 198.6149, 199.3040, 198.5697, 198.7950, 199.2953,
         199.5695, 198.9771, 199.0360],
        [198.6825, 198.7649, 198.3110, 199.0017, 198.2666, 198.4912, 198.9936,
         199.2665, 198.6747, 198.7340],
        [198.7062, 198.7875, 198.3345, 199.0252, 198.2891, 198.5140, 199.0162,
         199.2896, 198.6980, 198.7572],
        [198.7741, 198.8561, 198.4038, 199.0935, 198.3574, 198.5848, 199.0853,
         199.3605, 198.7664, 198.8255],
        [198.7715, 198.8534, 198.4016, 199.0910, 198.3540, 198.5828, 199.0827,
         199.3589, 198.7638, 198.8230],
        [198.7062, 198.7881, 198.3362, 199.0262, 198.2884, 198.5173, 199.0172,
         199.2936, 198.6987, 198.7580],
        [198.4807, 198

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.6213, 198.5809, 198.6714, 199.1728, 199.2709, 198.7316, 198.9296,
         198.9100, 198.9522, 198.7249],
        [198.6339, 198.5935, 198.6834, 199.1876, 199.2833, 198.7466, 198.9424,
         198.9227, 198.9641, 198.7355],
        [198.9864, 198.9445, 199.0334, 199.5367, 199.6346, 199.0961, 199.2939,
         199.2731, 199.3138, 199.0890],
        [198.4809, 198.4408, 198.5317, 199.0333, 199.1310, 198.5922, 198.7897,
         198.7700, 198.8127, 198.5842],
        [198.7775, 198.7372, 198.8274, 199.3308, 199.4258, 198.8890, 199.0853,
         199.0664, 199.1070, 198.8797],
        [198.5675, 198.5273, 198.6182, 199.1201, 199.2168, 198.6790, 198.8761,
         198.8563, 198.8984, 198.6707],
        [198.9385, 198.8965, 198.9855, 199.4877, 199.5875, 199.0475, 199.2463,
         199.2250, 199.2666, 199.0419],
        [198.5287, 198

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.6193, 198.6515, 198.8859, 198.4752, 198.6878, 198.2619, 198.7797,
         198.4971, 198.5521, 198.4921],
        [198.7197, 198.7517, 198.9857, 198.5751, 198.7873, 198.3617, 198.8796,
         198.5969, 198.6521, 198.5920],
        [198.8085, 198.8401, 199.0745, 198.6649, 198.8757, 198.4504, 198.9689,
         198.6858, 198.7415, 198.6807],
        [198.9353, 198.9681, 199.2011, 198.7939, 199.0048, 198.5792, 199.0951,
         198.8145, 198.8695, 198.8094],
        [198.7977, 198.8320, 199.0644, 198.6559, 198.8672, 198.4408, 198.9588,
         198.6770, 198.7313, 198.6722],
        [198.8227, 198.8567, 199.0892, 198.6809, 198.8921, 198.4657, 198.9836,
         198.7018, 198.7563, 198.6970],
        [198.8598, 198.8946, 199.1262, 198.7183, 198.9308, 198.5041, 199.0198,
         198.7398, 198.7938, 198.7352],
        [198.8584, 198.8912, 199.1242, 198.7155, 198.92

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.6130, 199.4258, 198.5750, 198.7413, 198.1788, 198.5922, 198.7695,
         199.0281, 198.6291, 199.1812],
        [198.5704, 199.3835, 198.5322, 198.6994, 198.1353, 198.5493, 198.7265,
         198.9857, 198.5864, 199.1395],
        [198.7651, 199.5780, 198.7270, 198.8936, 198.3311, 198.7436, 198.9210,
         199.1803, 198.7816, 199.3339],
        [198.4647, 199.2753, 198.4256, 198.5919, 198.0313, 198.4428, 198.6235,
         198.8802, 198.4808, 199.0326],
        [198.6715, 199.4839, 198.6334, 198.7999, 198.2381, 198.6503, 198.8285,
         199.0870, 198.6880, 199.2401],
        [198.5763, 199.3861, 198.5371, 198.7010, 198.1434, 198.5545, 198.7346,
         198.9905, 198.5924, 199.1417],
        [198.5287, 199.3423, 198.4908, 198.6577, 198.0932, 198.5082, 198.6846,
         198.9438, 198.5444, 199.0975],
        [198.3190, 199

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.8363, 198.3766, 198.5333, 198.5192, 198.5998, 198.2683, 198.7834,
         198.7233, 199.4660, 198.2688],
        [198.0385, 198.5795, 198.7360, 198.7230, 198.8009, 198.4718, 198.9869,
         198.9244, 199.6697, 198.4725],
        [197.9542, 198.4953, 198.6511, 198.6377, 198.7169, 198.3877, 198.9029,
         198.8403, 199.5853, 198.3870],
        [198.3357, 198.8766, 199.0320, 199.0197, 199.0977, 198.7678, 199.2827,
         199.2203, 199.9654, 198.7701],
        [198.0949, 198.6361, 198.7942, 198.7817, 198.8567, 198.5286, 199.0436,
         198.9807, 199.7271, 198.5311],
        [197.7833, 198.3235, 198.4826, 198.4688, 198.5465, 198.2151, 198.7303,
         198.6709, 199.4137, 198.2181],
        [198.1293, 198.6703, 198.8263, 198.8134, 198.8916, 198.5633, 199.0778,
         199.0146, 199.7609, 198.5634],
        [198.2127, 198

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.9975, 198.5605, 198.5396, 198.1279, 198.9054, 198.7204, 198.4894,
         198.4521, 198.8249, 198.9673],
        [197.9738, 198.5367, 198.5157, 198.1026, 198.8815, 198.6969, 198.4653,
         198.4292, 198.8010, 198.9439],
        [197.9342, 198.5000, 198.4781, 198.0672, 198.8466, 198.6590, 198.4288,
         198.3907, 198.7651, 198.9055],
        [198.0331, 198.5951, 198.5739, 198.1620, 198.9398, 198.7549, 198.5243,
         198.4880, 198.8593, 199.0023],
        [198.1471, 198.7105, 198.6885, 198.2786, 199.0554, 198.8695, 198.6399,
         198.6019, 198.9749, 199.1163],
        [197.9250, 198.4882, 198.4665, 198.0553, 198.8339, 198.6475, 198.4174,
         198.3808, 198.7528, 198.8952],
        [198.0136, 198.5791, 198.5576, 198.1461, 198.9251, 198.7386, 198.5077,
         198.4695, 198.8441, 198.9845],
        [197.7151, 198

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.4052, 197.8489, 198.1590, 197.9561, 198.4518, 198.5341, 198.8583,
         199.0528, 198.3499, 198.2841],
        [198.2962, 197.7406, 198.0499, 197.8452, 198.3439, 198.4231, 198.7503,
         198.9447, 198.2420, 198.1727],
        [198.5565, 198.0003, 198.3094, 198.1059, 198.6037, 198.6828, 199.0107,
         199.2041, 198.5009, 198.4337],
        [198.5586, 198.0012, 198.3119, 198.1088, 198.6044, 198.6861, 199.0106,
         199.2050, 198.5040, 198.4359],
        [198.4993, 197.9419, 198.2526, 198.0498, 198.5448, 198.6274, 198.9513,
         199.1461, 198.4451, 198.3766],
        [198.7974, 198.2398, 198.5494, 198.3476, 198.8429, 198.9234, 199.2501,
         199.4438, 198.7424, 198.6739],
        [198.5728, 198.0164, 198.3255, 198.1225, 198.6196, 198.6992, 199.0269,
         199.2206, 198.5178, 198.4496],
        [198.5615, 198

tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.2631, 198.8046, 198.7926, 197.8693, 198.4328, 198.6111, 198.7616,
         197.7847, 197.9272, 198.3609],
        [198.5174, 199.0603, 199.0478, 198.1252, 198.6861, 198.8673, 199.0164,
         198.0387, 198.1833, 198.6180],
        [198.4267, 198.9707, 198.9580, 198.0342, 198.5958, 198.7786, 198.9256,
         197.9478, 198.0933, 198.5285],
        [198.2586, 198.8043, 198.7910, 197.8687, 198.4282, 198.6125, 198.7604,
         197.7812, 197.9277, 198.3628],
        [198.3438, 198.8857, 198.8736, 197.9503, 198.5134, 198.6921, 198.8422,
         197.8653, 198.0083, 198.4421],
        [198.5023, 199.0457, 199.0328, 198.1099, 198.6714, 198.8518, 199.0009,
         198.0236, 198.1684, 198.6023],
        [198.1977, 198.7402, 198.7277, 197.8042, 198.3682, 198.5457, 198.6961,
         197.7197, 197.8626, 198.2953],
        [198.2859, 198.8316, 198.8178, 197.8943, 198.45

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.4789, 198.4077, 198.9212, 198.7948, 198.0812, 198.9832, 198.3095,
         198.4082, 198.6702, 198.9287],
        [198.3827, 198.3126, 198.8253, 198.6986, 197.9858, 198.8866, 198.2134,
         198.3121, 198.5754, 198.8328],
        [198.4327, 198.3614, 198.8750, 198.7484, 198.0347, 198.9369, 198.2633,
         198.3622, 198.6240, 198.8826],
        [198.3107, 198.2395, 198.7531, 198.6273, 197.9133, 198.8164, 198.1399,
         198.2391, 198.5031, 198.7601],
        [198.4000, 198.3300, 198.8426, 198.7159, 198.0032, 198.9043, 198.2307,
         198.3295, 198.5927, 198.8503],
        [198.4127, 198.3437, 198.8556, 198.7295, 198.0169, 198.9206, 198.2427,
         198.3422, 198.6065, 198.8633],
        [198.3984, 198.3266, 198.8405, 198.7138, 197.9998, 198.9017, 198.2288,
         198.3277, 198.5894, 198.8480],
        [198.0847, 198

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.4462, 198.7030, 198.9505, 197.7992, 199.1443, 198.7750, 199.4945,
         198.7936, 198.6768, 198.4819],
        [198.3651, 198.6219, 198.8691, 197.7182, 199.0638, 198.6950, 199.4143,
         198.7139, 198.5970, 198.4013],
        [198.3800, 198.6405, 198.8879, 197.7336, 199.0814, 198.7103, 199.4312,
         198.7295, 198.6135, 198.4179],
        [198.3888, 198.6485, 198.8963, 197.7418, 199.0905, 198.7181, 199.4393,
         198.7387, 198.6227, 198.4269],
        [198.3691, 198.6287, 198.8756, 197.7234, 199.0687, 198.7010, 199.4204,
         198.7182, 198.6017, 198.4060],
        [198.3010, 198.5577, 198.8044, 197.6537, 198.9988, 198.6290, 199.3488,
         198.6477, 198.5311, 198.3360],
        [198.5163, 198.7764, 199.0241, 197.8706, 199.2178, 198.8490, 199.5687,
         198.8675, 198.7509, 198.5548],
        [198.2929, 198

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.4514, 198.8566, 198.7653, 198.4664, 198.5867, 198.6234, 198.1067,
         198.2601, 198.4280, 199.2739],
        [198.3227, 198.7302, 198.6372, 198.3379, 198.4563, 198.4966, 197.9780,
         198.1322, 198.3011, 199.1470],
        [198.5314, 198.9365, 198.8455, 198.5471, 198.6672, 198.7034, 198.1869,
         198.3408, 198.5087, 199.3551],
        [198.2806, 198.6900, 198.5968, 198.2975, 198.4148, 198.4559, 197.9377,
         198.0921, 198.2617, 199.1079],
        [198.2564, 198.6651, 198.5731, 198.2736, 198.3905, 198.4302, 197.9130,
         198.0667, 198.2370, 199.0838],
        [198.2238, 198.6337, 198.5403, 198.2408, 198.3575, 198.3992, 197.8807,
         198.0350, 198.2051, 199.0514],
        [198.4193, 198.8265, 198.7354, 198.4362, 198.5536, 198.5919, 198.0751,
         198.2292, 198.3988, 199.2462],
        [198.3369, 198

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.9134, 199.5031, 198.7727, 198.8564, 198.0378, 198.5551, 198.5600,
         198.7343, 198.8621, 198.8888],
        [198.7317, 199.3216, 198.5909, 198.6750, 197.8559, 198.3736, 198.3781,
         198.5541, 198.6824, 198.7071],
        [198.6343, 199.2234, 198.4935, 198.5756, 197.7587, 198.2751, 198.2790,
         198.4538, 198.5830, 198.6090],
        [198.5728, 199.1608, 198.4312, 198.5151, 197.6962, 198.2144, 198.2163,
         198.3921, 198.5213, 198.5469],
        [198.4001, 198.9889, 198.2591, 198.3419, 197.5240, 198.0414, 198.0443,
         198.2207, 198.3504, 198.3748],
        [198.5657, 199.1549, 198.4257, 198.5070, 197.6902, 198.2066, 198.2127,
         198.3846, 198.5131, 198.5413],
        [198.7052, 199.2931, 198.5634, 198.6473, 197.8288, 198.3466, 198.3478,
         198.5235, 198.6530, 198.6790],
        [198.5457, 199

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 9, 9, 9, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.3416, 198.0542, 197.9657, 198.0423, 198.3377, 198.1012, 198.0205,
         197.9315, 198.0127, 198.0985],
        [198.4338, 198.1467, 198.0585, 198.1342, 198.4290, 198.1928, 198.1131,
         198.0229, 198.1050, 198.1903],
        [198.5702, 198.2835, 198.1948, 198.2708, 198.5657, 198.3296, 198.2498,
         198.1584, 198.2411, 198.3260],
        [198.6195, 198.3329, 198.2444, 198.3199, 198.6142, 198.3786, 198.2995,
         198.2090, 198.2907, 198.3765],
        [198.7771, 198.4915, 198.4031, 198.4780, 198.7718, 198.5363, 198.4577,
         198.3665, 198.4487, 198.5342],
        [198.5263, 198.2405, 198.1520, 198.2276, 198.5222, 198.2858, 198.2061,
         198.1172, 198.1982, 198.2844],
        [198.6976, 198.4110, 198.3226, 198.3978, 198.6916, 198.4564, 198.3778,
         198.2866, 198.3687, 198.4543],
        [198.2725, 197

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.8382, 197.7377, 197.9109, 197.8733, 197.9076, 198.7519, 198.9573,
         197.3807, 197.7893, 198.1673],
        [198.1518, 198.0524, 198.2254, 198.1852, 198.2212, 199.0667, 199.2676,
         197.6964, 198.1023, 198.4772],
        [198.3971, 198.2979, 198.4707, 198.4322, 198.4659, 199.3125, 199.5141,
         197.9425, 198.3498, 198.7241],
        [198.1350, 198.0351, 198.2081, 198.1712, 198.2043, 199.0500, 199.2544,
         197.6789, 198.0883, 198.4648],
        [198.2323, 198.1330, 198.3060, 198.2663, 198.3017, 199.1475, 199.3484,
         197.7773, 198.1835, 198.5582],
        [198.0528, 197.9530, 198.1267, 198.0863, 198.1224, 198.9680, 199.1692,
         197.5966, 198.0031, 198.3793],
        [198.3038, 198.2043, 198.3759, 198.3402, 198.3717, 199.2169, 199.4225,
         197.8488, 198.2574, 198.6315],
        [198.1176, 198

tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.5174, 198.2342, 198.1956, 198.4487, 198.8966, 198.5378, 198.2946,
         198.2095, 198.4557, 198.2695],
        [198.4267, 198.1398, 198.1039, 198.3537, 198.8014, 198.4431, 198.2015,
         198.1150, 198.3608, 198.1751],
        [198.5639, 198.2769, 198.2410, 198.4909, 198.9381, 198.5802, 198.3391,
         198.2522, 198.4975, 198.3122],
        [198.4481, 198.1638, 198.1257, 198.3775, 198.8264, 198.4677, 198.2240,
         198.1387, 198.3861, 198.1987],
        [198.4879, 198.2012, 198.1642, 198.4160, 198.8627, 198.5062, 198.2639,
         198.1776, 198.4234, 198.2375],
        [198.2406, 197.9537, 197.9175, 198.1675, 198.6159, 198.2575, 198.0148,
         197.9289, 198.1759, 197.9892],
        [198.3938, 198.1080, 198.0708, 198.3233, 198.7696, 198.4126, 198.1705,
         198.0844, 198.3296, 198.1445],
        [198.3457, 198.0583, 198.0222, 198.2724, 198.72

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.0854, 197.8405, 197.9643, 198.2371, 198.2535, 197.5740, 197.9213,
         197.9366, 198.1739, 198.3724],
        [198.3480, 198.1005, 198.2249, 198.4996, 198.5133, 197.8362, 198.1836,
         198.1984, 198.4343, 198.6332],
        [198.3241, 198.0796, 198.2034, 198.4766, 198.4922, 197.8140, 198.1608,
         198.1758, 198.4131, 198.6112],
        [198.2025, 197.9563, 198.0814, 198.3562, 198.3683, 197.6914, 198.0401,
         198.0558, 198.2903, 198.4886],
        [198.2268, 197.9786, 198.1031, 198.3781, 198.3913, 197.7141, 198.0619,
         198.0769, 198.3123, 198.5115],
        [198.1988, 197.9549, 198.0789, 198.3521, 198.3674, 197.6887, 198.0364,
         198.0518, 198.2887, 198.4865],
        [198.3868, 198.1396, 198.2640, 198.5384, 198.5524, 197.8754, 198.2224,
         198.2371, 198.4733, 198.6721],
        [198.1942, 197

tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.8279, 198.1514, 198.7498, 198.2512, 198.9962, 198.5926, 198.4787,
         197.5076, 198.3996, 198.3471],
        [197.8421, 198.1648, 198.7663, 198.2647, 199.0103, 198.6072, 198.4944,
         197.5228, 198.4140, 198.3600],
        [197.4573, 197.7776, 198.3786, 197.8787, 198.6217, 198.2206, 198.1091,
         197.1361, 198.0285, 197.9732],
        [197.6902, 198.0122, 198.6122, 198.1112, 198.8568, 198.4546, 198.3413,
         197.3695, 198.2622, 198.2072],
        [197.6480, 197.9703, 198.5721, 198.0704, 198.8161, 198.4131, 198.3000,
         197.3287, 198.2197, 198.1659],
        [197.7810, 198.1036, 198.7055, 198.2035, 198.9493, 198.5463, 198.4334,
         197.4619, 198.3530, 198.2990],
        [197.6815, 198.0043, 198.6058, 198.1040, 198.8502, 198.4469, 198.3334,
         197.3624, 198.2535, 198.1999],
        [197.7718, 198.0945, 198.6918, 198.1940, 198.93

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.4555, 199.2292, 198.0276, 198.1286, 198.0494, 197.6810, 198.2222,
         198.3420, 197.8373, 197.3864],
        [198.3844, 199.1572, 197.9572, 198.0561, 197.9765, 197.6093, 198.1517,
         198.2722, 197.7658, 197.3149],
        [198.5522, 199.3252, 198.1251, 198.2241, 198.1447, 197.7776, 198.3197,
         198.4400, 197.9331, 197.4828],
        [198.4138, 199.1868, 197.9869, 198.0858, 198.0062, 197.6391, 198.1817,
         198.3021, 197.7954, 197.3445],
        [198.5818, 199.3569, 198.1544, 198.2564, 198.1780, 197.8092, 198.3498,
         198.4700, 197.9657, 197.5136],
        [198.3801, 199.1547, 197.9529, 198.0539, 197.9744, 197.6060, 198.1485,
         198.2685, 197.7632, 197.3113],
        [198.3665, 199.1397, 197.9390, 198.0387, 197.9593, 197.5915, 198.1335,
         198.2537, 197.7483, 197.2971],
        [198.3836, 199

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.4492, 198.1063, 198.1905, 198.1664, 198.3692, 197.6149, 197.6601,
         198.3902, 198.3457, 198.3444],
        [197.3630, 198.0215, 198.1043, 198.0790, 198.2833, 197.5278, 197.5745,
         198.3029, 198.2594, 198.2572],
        [197.6901, 198.3483, 198.4296, 198.4058, 198.6107, 197.8554, 197.9029,
         198.6288, 198.5868, 198.5848],
        [197.3691, 198.0274, 198.1084, 198.0840, 198.2901, 197.5329, 197.5814,
         198.3099, 198.2668, 198.2653],
        [197.6143, 198.2722, 198.3563, 198.3318, 198.5358, 197.7815, 197.8268,
         198.5541, 198.5116, 198.5096],
        [197.4523, 198.1113, 198.1926, 198.1674, 198.3733, 197.6169, 197.6648,
         198.3915, 198.3495, 198.3471],
        [197.3607, 198.0186, 198.1023, 198.0774, 198.2814, 197.5262, 197.5721,
         198.3015, 198.2577, 198.2560],
        [197.5675, 198

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.0266, 198.8789, 197.8534, 197.8493, 198.2063, 198.1630, 197.6429,
         197.7628, 197.8752, 197.8470],
        [198.3231, 199.1741, 198.1517, 198.1412, 198.4992, 198.4585, 197.9378,
         198.0617, 198.1696, 198.1418],
        [198.2122, 199.0647, 198.0381, 198.0352, 198.3931, 198.3475, 197.8280,
         197.9493, 198.0612, 198.0325],
        [198.0765, 198.9277, 197.9032, 197.8955, 198.2538, 198.2110, 197.6906,
         197.8135, 197.9230, 197.8948],
        [198.4222, 199.2745, 198.2502, 198.2436, 198.6012, 198.5585, 198.0384,
         198.1609, 198.2708, 198.2427],
        [198.1020, 198.9545, 197.9292, 197.9242, 198.2816, 198.2387, 197.7184,
         197.8387, 197.9507, 197.9225],
        [198.1680, 199.0210, 197.9947, 197.9886, 198.3471, 198.3044, 197.7838,
         197.9056, 198.0165, 197.9884],
        [198.1528, 199.0042, 197.9810, 197.9737, 198.33

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.4731, 198.2992, 199.0973, 198.3714, 198.9603, 197.9478, 197.9700,
         197.8164, 198.6420, 198.9804],
        [198.0772, 197.9037, 198.7007, 197.9749, 198.5638, 197.5513, 197.5728,
         197.4205, 198.2443, 198.5835],
        [198.3541, 198.1824, 198.9781, 198.2515, 198.8403, 197.8282, 197.8515,
         197.6969, 198.5224, 198.8632],
        [198.2570, 198.0820, 198.8800, 198.1542, 198.7441, 197.7300, 197.7513,
         197.5986, 198.4259, 198.7638],
        [198.2036, 198.0317, 198.8277, 198.1013, 198.6899, 197.6782, 197.7009,
         197.5472, 198.3709, 198.7113],
        [198.2818, 198.1069, 198.9049, 198.1791, 198.7690, 197.7549, 197.7762,
         197.6234, 198.4508, 198.7885],
        [198.2804, 198.1101, 198.9045, 198.1775, 198.7660, 197.7547, 197.7786,
         197.6234, 198.4478, 198.7902],
        [198.3631, 198

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.1499, 198.3487, 197.6514, 197.9919, 197.9115, 197.9983, 197.0925,
         198.0330, 198.6098, 197.9877],
        [197.4136, 198.6116, 197.9146, 198.2565, 198.1747, 198.2623, 197.3578,
         198.2958, 198.8743, 198.2506],
        [197.4884, 198.6869, 197.9913, 198.3303, 198.2487, 198.3374, 197.4319,
         198.3696, 198.9493, 198.3257],
        [197.1661, 198.3649, 197.6679, 198.0078, 197.9273, 198.0148, 197.1086,
         198.0490, 198.6261, 198.0039],
        [197.4077, 198.6063, 197.9101, 198.2493, 198.1675, 198.2576, 197.3511,
         198.2895, 198.8688, 198.2450],
        [197.3392, 198.5381, 197.8382, 198.1808, 198.1003, 198.1875, 197.2814,
         198.2218, 198.7985, 198.1762],
        [197.4078, 198.6065, 197.9061, 198.2494, 198.1685, 198.2564, 197.3503,
         198.2902, 198.8673, 198.2445],
        [197.5254, 198.7241, 198.0253, 198.3672, 198.28

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.1101, 198.2419, 198.1227, 197.8318, 198.0282, 198.3925, 197.9502,
         197.2936, 197.2356, 197.9890],
        [197.8461, 197.9802, 197.8620, 197.5688, 197.7660, 198.1315, 197.6873,
         197.0312, 196.9738, 197.7275],
        [198.1320, 198.2649, 198.1457, 197.8533, 198.0493, 198.4171, 197.9712,
         197.3147, 197.2580, 198.0133],
        [198.1404, 198.2732, 198.1543, 197.8619, 198.0582, 198.4248, 197.9806,
         197.3234, 197.2665, 198.0213],
        [198.1708, 198.3044, 198.1865, 197.8940, 198.0906, 198.4554, 198.0135,
         197.3563, 197.2993, 198.0510],
        [197.9050, 198.0395, 197.9216, 197.6279, 197.8254, 198.1899, 197.7479,
         197.0902, 197.0330, 197.7863],
        [198.1163, 198.2484, 198.1292, 197.8381, 198.0344, 198.3994, 197.9563,
         197.2999, 197.2423, 197.9956],
        [197.9372, 198.0706, 197.9521, 197.6597, 197.85

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.2610, 197.7682, 198.6763, 198.4340, 198.0276, 197.3103, 197.8334,
         198.2146, 198.0845, 197.9090],
        [197.9646, 197.4708, 198.3791, 198.1357, 197.7317, 197.0122, 197.5371,
         197.9201, 197.7882, 197.6127],
        [198.1790, 197.6851, 198.5949, 198.3510, 197.9457, 197.2279, 197.7518,
         198.1308, 198.0006, 197.8252],
        [197.9862, 197.4924, 198.4007, 198.1576, 197.7532, 197.0340, 197.5587,
         197.9413, 197.8097, 197.6342],
        [198.1719, 197.6787, 198.5872, 198.3440, 197.9389, 197.2217, 197.7447,
         198.1255, 197.9954, 197.8197],
        [198.2428, 197.7496, 198.6593, 198.4143, 198.0113, 197.2938, 197.8163,
         198.1966, 198.0667, 197.8907],
        [198.3575, 197.8644, 198.7743, 198.5301, 198.1250, 197.4084, 197.9308,
         198.3092, 198.1801, 198.0043],
        [198.3913, 197

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.1086, 197.8722, 198.1573, 198.1087, 198.1957, 198.6787, 197.9137,
         197.8820, 197.3364, 198.0214],
        [197.9785, 197.7429, 198.0293, 197.9794, 198.0665, 198.5515, 197.7841,
         197.7516, 197.2036, 197.8894],
        [198.1321, 197.8983, 198.1834, 198.1330, 198.2209, 198.7051, 197.9389,
         197.9029, 197.3593, 198.0451],
        [198.0382, 197.8038, 198.0897, 198.0393, 198.1268, 198.6115, 197.8446,
         197.8099, 197.2643, 197.9503],
        [198.0053, 197.7705, 198.0564, 198.0062, 198.0936, 198.5783, 197.8113,
         197.7772, 197.2309, 197.9170],
        [197.9951, 197.7586, 198.0442, 197.9954, 198.0825, 198.5656, 197.8002,
         197.7686, 197.2224, 197.9087],
        [197.9212, 197.6850, 197.9714, 197.9220, 198.0090, 198.4935, 197.7265,
         197.6948, 197.1468, 197.8332],
        [198.1160, 197

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.2819, 198.2729, 198.2390, 197.8691, 197.9436, 198.2632, 197.9769,
         198.9456, 198.4992, 198.8596],
        [198.4188, 198.4107, 198.3766, 198.0068, 198.0813, 198.4010, 198.1144,
         199.0828, 198.6378, 198.9973],
        [198.1003, 198.0909, 198.0564, 197.6853, 197.7607, 198.0800, 197.7939,
         198.7629, 198.3158, 198.6765],
        [198.1518, 198.1425, 198.1080, 197.7371, 197.8123, 198.1317, 197.8455,
         198.8145, 198.3670, 198.7281],
        [198.2002, 198.1907, 198.1571, 197.7868, 197.8616, 198.1811, 197.8951,
         198.8638, 198.4183, 198.7777],
        [198.3289, 198.3209, 198.2865, 197.9165, 197.9907, 198.3109, 198.0238,
         198.9923, 198.5456, 198.9068],
        [197.9965, 197.9872, 197.9531, 197.5814, 197.6569, 197.9766, 197.6902,
         198.6590, 198.2129, 198.5728],
        [198.3184, 198

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.5817, 197.9973, 197.5584, 198.7900, 198.4428, 197.8895, 197.8475,
         197.9802, 197.6832, 198.1182],
        [197.6172, 198.0343, 197.5951, 198.8266, 198.4794, 197.9279, 197.8833,
         198.0173, 197.7185, 198.1544],
        [197.5793, 197.9966, 197.5587, 198.7919, 198.4448, 197.8910, 197.8495,
         197.9819, 197.6820, 198.1196],
        [197.7334, 198.1510, 197.7125, 198.9447, 198.5971, 198.0460, 198.0011,
         198.1350, 197.8350, 198.2714],
        [197.5859, 198.0023, 197.5645, 198.7971, 198.4487, 197.8959, 197.8540,
         197.9864, 197.6884, 198.1235],
        [197.4657, 197.8825, 197.4443, 198.6768, 198.3281, 197.7757, 197.7333,
         197.8661, 197.5682, 198.0029],
        [197.5237, 197.9409, 197.5014, 198.7330, 198.3863, 197.8339, 197.7900,
         197.9240, 197.6250, 198.0615],
        [197.5544, 197.9718, 197.5337, 198.7670, 198.42

tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.8227, 197.7390, 197.8974, 198.1207, 198.6652, 198.7791, 197.3207,
         197.9222, 197.6331, 199.0193],
        [197.7839, 197.6998, 197.8582, 198.0816, 198.6260, 198.7400, 197.2814,
         197.8828, 197.5941, 198.9800],
        [197.5573, 197.4736, 197.6321, 197.8566, 198.3987, 198.5136, 197.0561,
         197.6571, 197.3676, 198.7541],
        [197.7886, 197.7042, 197.8627, 198.0874, 198.6303, 198.7444, 197.2874,
         197.8886, 197.5980, 198.9854],
        [197.6285, 197.5447, 197.7032, 197.9268, 198.4711, 198.5848, 197.1265,
         197.7281, 197.4388, 198.8251],
        [197.8005, 197.7169, 197.8752, 198.0988, 198.6427, 198.7569, 197.2986,
         197.9002, 197.6109, 198.9971],
        [197.4680, 197.3849, 197.5434, 197.7678, 198.3102, 198.4248, 196.9673,
         197.5687, 197.2786, 198.6656],
        [197.6701, 197.5860, 197.7444, 197.9682, 198.51

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.7981, 197.8351, 197.7030, 197.8608, 197.7985, 197.6056, 197.7658,
         198.0448, 197.6025, 197.9721],
        [198.8618, 197.8980, 197.7650, 197.9257, 197.8588, 197.6678, 197.8260,
         198.1059, 197.6641, 198.0354],
        [198.7973, 197.8348, 197.7027, 197.8607, 197.7974, 197.6042, 197.7641,
         198.0433, 197.6006, 197.9706],
        [198.8847, 197.9220, 197.7901, 197.9471, 197.8854, 197.6921, 197.8527,
         198.1317, 197.6895, 198.0587],
        [198.7729, 197.8089, 197.6771, 197.8358, 197.7726, 197.5806, 197.7402,
         198.0189, 197.5766, 197.9467],
        [199.1316, 198.1674, 198.0360, 198.1940, 198.1303, 197.9388, 198.0977,
         198.3769, 197.9356, 198.3052],
        [198.6812, 197.7187, 197.5860, 197.7450, 197.6808, 197.4880, 197.6478,
         197.9272, 197.4844, 197.8550],
        [198.9201, 197

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.9296, 197.9197, 197.5828, 198.0082, 197.0828, 197.8171, 197.6036,
         197.6832, 198.4038, 197.9232],
        [197.9334, 197.9238, 197.5871, 198.0125, 197.0866, 197.8208, 197.6074,
         197.6873, 198.4074, 197.9269],
        [198.0943, 198.0852, 197.7492, 198.1761, 197.2471, 197.9810, 197.7690,
         197.8506, 198.5670, 198.0878],
        [197.9836, 197.9742, 197.6369, 198.0648, 197.1365, 197.8705, 197.6588,
         197.7398, 198.4571, 197.9780],
        [197.7406, 197.7316, 197.3945, 197.8223, 196.8938, 197.6280, 197.4143,
         197.4954, 198.2145, 197.7345],
        [197.9745, 197.9652, 197.6282, 198.0546, 197.1282, 197.8620, 197.6496,
         197.7295, 198.4490, 197.9691],
        [198.1121, 198.1018, 197.7649, 198.1908, 197.2646, 197.9990, 197.7871,
         197.8672, 198.5854, 198.1058],
        [197.9120, 197

tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.6818, 197.6923, 198.5299, 197.8282, 197.6527, 197.8019, 197.1406,
         197.8683, 197.1428, 198.6047],
        [197.9023, 197.9114, 198.7483, 198.0470, 197.8710, 198.0222, 197.3585,
         198.0867, 197.3624, 198.8226],
        [197.9387, 197.9494, 198.7879, 198.0863, 197.9084, 198.0590, 197.3991,
         198.1257, 197.4010, 198.8618],
        [197.5452, 197.5545, 198.3921, 197.6906, 197.5150, 197.6646, 197.0010,
         197.7315, 197.0045, 198.4661],
        [197.5821, 197.5926, 198.4310, 197.7292, 197.5528, 197.7019, 197.0413,
         197.7698, 197.0429, 198.5052],
        [197.7011, 197.7096, 198.5467, 197.8454, 197.6695, 197.8206, 197.1557,
         197.8860, 197.1600, 198.6205],
        [197.6304, 197.6392, 198.4776, 197.7764, 197.5992, 197.7495, 197.0858,
         197.8170, 197.0898, 198.5506],
        [197.9031, 197.9130, 198.7511, 198.0496, 197.87

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.0826, 197.6656, 198.0467, 197.6822, 198.1904, 197.1934, 198.5878,
         198.0878, 198.1734, 198.0791],
        [197.9469, 197.5291, 197.9104, 197.5439, 198.0558, 197.0569, 198.4542,
         197.9512, 198.0403, 197.9431],
        [197.9525, 197.5335, 197.9163, 197.5481, 198.0595, 197.0617, 198.4585,
         197.9567, 198.0434, 197.9465],
        [197.8750, 197.4573, 197.8385, 197.4724, 197.9839, 196.9849, 198.3819,
         197.8793, 197.9683, 197.8713],
        [197.9250, 197.5072, 197.8885, 197.5214, 198.0338, 197.0351, 198.4325,
         197.9296, 198.0187, 197.9213],
        [197.8898, 197.4718, 197.8538, 197.4883, 197.9959, 196.9994, 198.3928,
         197.8945, 197.9792, 197.8852],
        [197.7924, 197.3741, 197.7557, 197.3878, 197.9008, 196.9019, 198.2995,
         197.7966, 197.8860, 197.7880],
        [197.6495, 197

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.0515, 198.7419, 197.7087, 197.7414, 197.5682, 197.5477, 197.7952,
         197.5891, 197.9724, 197.8661],
        [197.9456, 198.6374, 197.6076, 197.6402, 197.4635, 197.4447, 197.6934,
         197.4853, 197.8706, 197.7640],
        [197.8293, 198.5206, 197.4897, 197.5227, 197.3470, 197.3266, 197.5762,
         197.3687, 197.7535, 197.6467],
        [197.9206, 198.6122, 197.5822, 197.6144, 197.4377, 197.4188, 197.6666,
         197.4595, 197.8455, 197.7392],
        [198.1467, 198.8371, 197.8049, 197.8382, 197.6650, 197.6429, 197.8921,
         197.6853, 198.0680, 197.9613],
        [198.1290, 198.8205, 197.7907, 197.8231, 197.6469, 197.6273, 197.8750,
         197.6680, 198.0533, 197.9471],
        [198.0688, 198.7602, 197.7300, 197.7628, 197.5870, 197.5675, 197.8165,
         197.6084, 197.9928, 197.8859],
        [198.0688, 198

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.4883, 197.7193, 197.4407, 197.3191, 196.9936, 198.5784, 197.5803,
         198.5003, 197.5440, 197.8119],
        [197.7675, 197.9989, 197.7215, 197.5991, 197.2750, 198.8585, 197.8591,
         198.7789, 197.8248, 198.0898],
        [197.4245, 197.6570, 197.3775, 197.2554, 196.9303, 198.5159, 197.5172,
         198.4364, 197.4809, 197.7485],
        [197.5604, 197.7914, 197.5128, 197.3893, 197.0651, 198.6507, 197.6506,
         198.5722, 197.6147, 197.8848],
        [197.5779, 197.8089, 197.5303, 197.4067, 197.0825, 198.6683, 197.6681,
         198.5896, 197.6321, 197.9022],
        [197.4885, 197.7201, 197.4410, 197.3179, 196.9937, 198.5791, 197.5797,
         198.5006, 197.5433, 197.8132],
        [197.4354, 197.6672, 197.3874, 197.2623, 196.9382, 198.5263, 197.5248,
         198.4468, 197.4882, 197.7610],
        [197.5599, 197

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.2499, 197.6858, 197.6303, 197.7567, 197.8790, 197.6859, 196.9375,
         197.6326, 197.3187, 197.5954],
        [197.2859, 197.7231, 197.6671, 197.7920, 197.9155, 197.7217, 196.9717,
         197.6697, 197.3554, 197.6315],
        [197.1090, 197.5461, 197.4889, 197.6184, 197.7400, 197.5445, 196.7956,
         197.4943, 197.1776, 197.4525],
        [197.2870, 197.7240, 197.6690, 197.7916, 197.9159, 197.7233, 196.9742,
         197.6687, 197.3569, 197.6340],
        [197.2841, 197.7208, 197.6637, 197.7926, 197.9145, 197.7197, 196.9702,
         197.6693, 197.3532, 197.6281],
        [197.2163, 197.6529, 197.5974, 197.7227, 197.8458, 197.6525, 196.9042,
         197.5985, 197.2856, 197.5623],
        [196.9909, 197.4275, 197.3717, 197.4996, 197.6211, 197.4266, 196.6783,
         197.3749, 197.0587, 197.3350],
        [197.1364, 197

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.4897, 197.5669, 197.6947, 197.4580, 197.8508, 197.5955, 198.6829,
         197.7393, 198.6041, 197.4213],
        [197.2386, 197.3168, 197.4434, 197.2074, 197.6005, 197.3456, 198.4328,
         197.4894, 198.3529, 197.1715],
        [197.5227, 197.6007, 197.7279, 197.4923, 197.8849, 197.6289, 198.7171,
         197.7734, 198.6371, 197.4562],
        [197.3923, 197.4720, 197.5978, 197.3630, 197.7554, 197.5006, 198.5881,
         197.6441, 198.5074, 197.3276],
        [197.5711, 197.6499, 197.7767, 197.5420, 197.9342, 197.6782, 198.7659,
         197.8231, 198.6863, 197.5068],
        [197.6194, 197.6980, 197.8251, 197.5899, 197.9820, 197.7263, 198.8141,
         197.8708, 198.7345, 197.5542],
        [197.3218, 197.3996, 197.5268, 197.2906, 197.6834, 197.4287, 198.5149,
         197.5726, 198.4368, 197.2547],
        [197.4440, 197

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.5566, 197.6921, 197.2594, 197.6345, 197.6422, 197.7646, 196.6897,
         197.1556, 198.2763, 196.9308],
        [197.8347, 197.9701, 197.5386, 197.9142, 197.9201, 198.0437, 196.9684,
         197.4361, 198.5550, 197.2114],
        [197.2703, 197.4048, 196.9725, 197.3473, 197.3554, 197.4782, 196.4029,
         196.8686, 197.9893, 196.6442],
        [197.4939, 197.6302, 197.1966, 197.5735, 197.5800, 197.7029, 196.6272,
         197.0936, 198.2134, 196.8682],
        [197.6188, 197.7536, 197.3205, 197.6966, 197.7018, 197.8263, 196.7506,
         197.2176, 198.3380, 196.9949],
        [197.6144, 197.7501, 197.3199, 197.6946, 197.7029, 197.8249, 196.7498,
         197.2171, 198.3352, 196.9896],
        [197.6058, 197.7416, 197.3073, 197.6847, 197.6894, 197.8137, 196.7377,
         197.2047, 198.3248, 196.9812],
        [197.5797, 197

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.0415, 197.3842, 197.5375, 198.0638, 197.3040, 197.2246, 197.2287,
         197.8209, 197.2377, 196.8811],
        [198.0388, 197.3795, 197.5351, 198.0620, 197.2997, 197.2215, 197.2253,
         197.8159, 197.2360, 196.8772],
        [198.0324, 197.3733, 197.5288, 198.0552, 197.2934, 197.2149, 197.2184,
         197.8089, 197.2296, 196.8705],
        [198.0200, 197.3611, 197.5169, 198.0443, 197.2822, 197.2043, 197.2087,
         197.8002, 197.2173, 196.8599],
        [197.9500, 197.2913, 197.4455, 197.9717, 197.2110, 197.1319, 197.1355,
         197.7265, 197.1466, 196.7878],
        [198.1985, 197.5407, 197.6961, 198.2232, 197.4610, 197.3830, 197.3874,
         197.9789, 197.3961, 197.0394],
        [198.3642, 197.7069, 197.8610, 198.3877, 197.6260, 197.5468, 197.5517,
         198.1427, 197.5612, 197.2046],
        [197.8301, 197.1714, 197.3241, 197.8501, 197.09

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.3147, 197.9624, 197.7177, 197.3155, 197.2268, 196.7743, 197.4438,
         197.6253, 197.9377, 197.4830],
        [197.3566, 198.0037, 197.7596, 197.3568, 197.2688, 196.8165, 197.4854,
         197.6670, 197.9795, 197.5247],
        [197.1678, 197.8096, 197.5663, 197.1643, 197.0764, 196.6262, 197.2921,
         197.4757, 197.7896, 197.3331],
        [197.3524, 197.9957, 197.7524, 197.3496, 197.2624, 196.8120, 197.4776,
         197.6606, 197.9740, 197.5188],
        [197.3701, 198.0114, 197.7693, 197.3656, 197.2796, 196.8298, 197.4939,
         197.6775, 197.9911, 197.5356],
        [197.2989, 197.9440, 197.6993, 197.2984, 197.2094, 196.7574, 197.4256,
         197.6100, 197.9229, 197.4660],
        [197.4945, 198.1365, 197.8916, 197.4916, 197.4030, 196.9533, 197.6168,
         197.8038, 198.1168, 197.6604],
        [197.4718, 198.1171, 197.8719, 197.4714, 197.38

tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.0002, 197.2012, 197.6377, 197.6793, 197.4822, 197.3406, 197.6318,
         197.4223, 197.5678, 197.5663],
        [196.9185, 197.1188, 197.5561, 197.5974, 197.4003, 197.2585, 197.5500,
         197.3401, 197.4859, 197.4844],
        [196.9571, 197.1576, 197.5918, 197.6353, 197.4368, 197.3001, 197.5864,
         197.3782, 197.5246, 197.5232],
        [197.0831, 197.2845, 197.7182, 197.7620, 197.5638, 197.4262, 197.7137,
         197.5044, 197.6512, 197.6494],
        [196.7166, 196.9153, 197.3533, 197.3943, 197.1979, 197.0598, 197.3448,
         197.1382, 197.2839, 197.2822],
        [196.9639, 197.1646, 197.6016, 197.6430, 197.4458, 197.3042, 197.5954,
         197.3857, 197.5314, 197.5300],
        [196.8324, 197.0320, 197.4701, 197.5112, 197.3145, 197.1728, 197.4634,
         197.2538, 197.3999, 197.3980],
        [196.8326, 197.0322, 197.4690, 197.5106, 197.31

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.4148, 197.3620, 197.1040, 197.0489, 197.2109, 197.8625, 197.1902,
         198.3584, 197.5927, 197.7591],
        [197.3346, 197.2850, 197.0258, 196.9684, 197.1306, 197.7815, 197.1102,
         198.2783, 197.5155, 197.6818],
        [197.1367, 197.0849, 196.8262, 196.7694, 196.9323, 197.5833, 196.9142,
         198.0792, 197.3157, 197.4814],
        [197.4467, 197.3947, 197.1365, 197.0812, 197.2430, 197.8949, 197.2210,
         198.3909, 197.6256, 197.7916],
        [197.4830, 197.4313, 197.1728, 197.1171, 197.2784, 197.9289, 197.2600,
         198.4255, 197.6609, 197.8282],
        [197.4539, 197.4034, 197.1446, 197.0875, 197.2496, 197.8994, 197.2307,
         198.3967, 197.6329, 197.8007],
        [197.4923, 197.4402, 197.1821, 197.1269, 197.2886, 197.9404, 197.2666,
         198.4364, 197.6711, 197.8372],
        [197.3608, 197

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[196.5652, 197.4010, 197.2010, 197.1980, 197.5750, 197.5142, 197.1659,
         197.0107, 198.2325, 197.6390],
        [196.6560, 197.4918, 197.2912, 197.2891, 197.6642, 197.6053, 197.2563,
         197.1010, 198.3225, 197.7295],
        [196.6383, 197.4738, 197.2719, 197.2720, 197.6448, 197.5883, 197.2373,
         197.0825, 198.3036, 197.7112],
        [196.7874, 197.6232, 197.4220, 197.4212, 197.7946, 197.7368, 197.3876,
         197.2323, 198.4535, 197.8606],
        [196.6578, 197.4935, 197.2923, 197.2904, 197.6666, 197.6064, 197.2589,
         197.1037, 198.3252, 197.7310],
        [196.5781, 197.4138, 197.2130, 197.2105, 197.5882, 197.5266, 197.1794,
         197.0242, 198.2460, 197.6516],
        [196.5264, 197.3619, 197.1613, 197.1591, 197.5369, 197.4754, 197.1271,
         196.9721, 198.1944, 197.5998],
        [196.5424, 197

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[198.1768, 197.2489, 196.5952, 197.8190, 197.9186, 197.2646, 196.8786,
         197.3036, 197.5581, 197.4724],
        [198.1482, 197.2185, 196.5658, 197.7883, 197.8918, 197.2368, 196.8486,
         197.2748, 197.5309, 197.4437],
        [198.1274, 197.1995, 196.5444, 197.7665, 197.8711, 197.2161, 196.8289,
         197.2541, 197.5098, 197.4233],
        [197.8138, 196.8875, 196.2315, 197.4546, 197.5562, 196.9013, 196.5161,
         196.9410, 197.1960, 197.1096],
        [198.1096, 197.1826, 196.5277, 197.7519, 197.8508, 197.1968, 196.8117,
         197.2368, 197.4907, 197.4053],
        [198.0470, 197.1191, 196.4654, 197.6890, 197.7890, 197.1348, 196.7487,
         197.1738, 197.4287, 197.3425],
        [197.9467, 197.0178, 196.3647, 197.5877, 197.6899, 197.0348, 196.6474,
         197.0738, 197.3294, 197.2421],
        [197.9337, 197

tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.2028, 196.5194, 197.3676, 197.0576, 196.9640, 197.1137, 196.7806,
         197.3071, 196.8679, 196.9082],
        [197.3550, 196.6716, 197.5198, 197.2113, 197.1157, 197.2664, 196.9333,
         197.4591, 197.0196, 197.0605],
        [197.4616, 196.7782, 197.6273, 197.3154, 197.2227, 197.3714, 197.0391,
         197.5648, 197.1262, 197.1668],
        [197.3041, 196.6221, 197.4690, 197.1582, 197.0655, 197.2157, 196.8820,
         197.4092, 196.9703, 197.0098],
        [197.3863, 196.7046, 197.5511, 197.2423, 197.1474, 197.2990, 196.9649,
         197.4917, 197.0529, 197.0924],
        [197.4796, 196.7980, 197.6442, 197.3360, 197.2406, 197.3922, 197.0585,
         197.5850, 197.1457, 197.1857],
        [197.5003, 196.8190, 197.6651, 197.3551, 197.2616, 197.4123, 197.0788,
         197.6057, 197.1667, 197.2063],
        [197.0968, 196.4124, 197.2611, 196.9516, 196.85

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.8292, 197.2382, 197.0467, 196.4907, 197.0904, 197.0147, 197.3514,
         196.6473, 196.9951, 197.1969],
        [197.8930, 197.3040, 197.1126, 196.5573, 197.1564, 197.0801, 197.4165,
         196.7119, 197.0603, 197.2607],
        [197.7740, 197.1869, 196.9946, 196.4392, 197.0381, 196.9652, 197.2991,
         196.5941, 196.9419, 197.1432],
        [197.6687, 197.0785, 196.8865, 196.3318, 196.9313, 196.8545, 197.1915,
         196.4866, 196.8350, 197.0363],
        [197.8691, 197.2794, 197.0878, 196.5313, 197.1304, 197.0573, 197.3925,
         196.6879, 197.0362, 197.2374],
        [197.9720, 197.3815, 197.1905, 196.6352, 197.2343, 197.1565, 197.4946,
         196.7901, 197.1389, 197.3390],
        [198.0889, 197.5007, 197.3094, 196.7529, 197.3513, 197.2784, 197.6134,
         196.9086, 197.2572, 197.4574],
        [197.8349, 197.2477, 197.0556, 196.4999, 197.09

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.1910, 197.2730, 197.4286, 197.1682, 198.3557, 197.2469, 197.4165,
         197.2238, 197.3633, 197.1774],
        [196.9901, 197.0733, 197.2283, 196.9687, 198.1554, 197.0471, 197.2171,
         197.0231, 197.1645, 196.9775],
        [197.0764, 197.1602, 197.3143, 197.0562, 198.2424, 197.1336, 197.3048,
         197.1103, 197.2520, 197.0660],
        [197.3237, 197.4061, 197.5611, 197.3028, 198.4893, 197.3792, 197.5491,
         197.3570, 197.4970, 197.3121],
        [196.9896, 197.0733, 197.2282, 196.9699, 198.1560, 197.0469, 197.2178,
         197.0234, 197.1653, 196.9796],
        [196.9073, 196.9901, 197.1471, 196.8869, 198.0742, 196.9644, 197.1348,
         196.9408, 197.0813, 196.8977],
        [197.1128, 197.1962, 197.3506, 197.0920, 198.2786, 197.1697, 197.3406,
         197.1466, 197.2877, 197.1019],
        [196.9269, 197

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[196.9393, 197.3489, 197.1119, 197.3384, 197.1451, 197.6986, 197.0455,
         197.3718, 197.2394, 197.1986],
        [196.8854, 197.2950, 197.0558, 197.2853, 197.0920, 197.6462, 196.9915,
         197.3188, 197.1864, 197.1450],
        [196.8692, 197.2775, 197.0399, 197.2665, 197.0743, 197.6267, 196.9732,
         197.2992, 197.1681, 197.1284],
        [197.1487, 197.5566, 197.3187, 197.5456, 197.3542, 197.9067, 197.2531,
         197.5793, 197.4474, 197.4066],
        [196.8413, 197.2502, 197.0136, 197.2398, 197.0463, 197.5997, 196.9463,
         197.2724, 197.1409, 197.1005],
        [196.9614, 197.3702, 197.1338, 197.3601, 197.1664, 197.7202, 197.0668,
         197.3929, 197.2610, 197.2201],
        [196.8976, 197.3080, 197.0707, 197.2981, 197.1039, 197.6583, 197.0048,
         197.3316, 197.1987, 197.1573],
        [197.2096, 197

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[196.3887, 197.8326, 196.6299, 197.3533, 196.9007, 197.0974, 197.3487,
         197.0155, 197.0360, 198.2065],
        [196.2891, 197.7302, 196.5280, 197.2520, 196.7990, 196.9970, 197.2499,
         196.9153, 196.9355, 198.1048],
        [196.3681, 197.8089, 196.6068, 197.3310, 196.8781, 197.0756, 197.3282,
         196.9947, 197.0142, 198.1839],
        [196.2988, 197.7429, 196.5388, 197.2631, 196.8096, 197.0076, 197.2594,
         196.9251, 196.9462, 198.1166],
        [196.3646, 197.8078, 196.6036, 197.3288, 196.8752, 197.0726, 197.3236,
         196.9916, 197.0112, 198.1824],
        [196.1959, 197.6388, 196.4348, 197.1599, 196.7058, 196.9041, 197.1559,
         196.8226, 196.8425, 198.0128],
        [196.3020, 197.7422, 196.5386, 197.2646, 196.8106, 197.0089, 197.2610,
         196.9291, 196.9474, 198.1178],
        [196.4469, 197

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.1670, 196.8707, 197.2290, 196.9437, 197.6263, 196.7675, 197.2916,
         196.8658, 196.8485, 196.4094],
        [196.8807, 196.5838, 196.9419, 196.6567, 197.3390, 196.4806, 197.0051,
         196.5796, 196.5616, 196.1222],
        [197.2521, 196.9558, 197.3140, 197.0301, 197.7117, 196.8533, 197.3772,
         196.9504, 196.9328, 196.4960],
        [197.0987, 196.8010, 197.1600, 196.8747, 197.5574, 196.6981, 197.2228,
         196.7982, 196.7795, 196.3403],
        [197.1965, 196.9022, 197.2581, 196.9727, 197.6551, 196.7991, 197.3219,
         196.8968, 196.8793, 196.4412],
        [197.1899, 196.8966, 197.2516, 196.9681, 197.6488, 196.7945, 197.3164,
         196.8891, 196.8720, 196.4372],
        [197.3545, 197.0576, 197.4155, 197.1318, 197.8136, 196.9560, 197.4798,
         197.0548, 197.0349, 196.5997],
        [197.1749, 196.8805, 197.2362, 196.9526, 197.63

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.1276, 197.0518, 196.8945, 197.2519, 196.8263, 197.2107, 197.0971,
         197.2261, 196.9373, 197.3969],
        [197.0710, 196.9943, 196.8385, 197.1960, 196.7685, 197.1552, 197.0415,
         197.1698, 196.8805, 197.3397],
        [197.3961, 197.3214, 197.1646, 197.5219, 197.0961, 197.4801, 197.3676,
         197.4945, 197.2065, 197.6656],
        [197.3149, 197.2406, 197.0800, 197.4400, 197.0152, 197.3960, 197.2848,
         197.4126, 197.1241, 197.5827],
        [197.1354, 197.0599, 196.9036, 197.2612, 196.8351, 197.2202, 197.1071,
         197.2336, 196.9453, 197.4029],
        [197.1892, 197.1147, 196.9538, 197.3135, 196.8894, 197.2698, 197.1581,
         197.2867, 196.9982, 197.4572],
        [197.2656, 197.1905, 197.0324, 197.3909, 196.9650, 197.3486, 197.2362,
         197.3639, 197.0753, 197.5341],
        [197.3433, 197

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[196.4702, 196.8407, 196.6501, 196.8181, 196.6958, 196.8542, 197.1517,
         196.8326, 196.6299, 196.1860],
        [196.6054, 196.9713, 196.7848, 196.9541, 196.8304, 196.9895, 197.2862,
         196.9668, 196.7635, 196.3209],
        [196.5613, 196.9285, 196.7423, 196.9087, 196.7853, 196.9438, 197.2407,
         196.9222, 196.7194, 196.2775],
        [196.6570, 197.0239, 196.8375, 197.0042, 196.8813, 197.0396, 197.3366,
         197.0179, 196.8149, 196.3727],
        [196.5598, 196.9257, 196.7392, 196.9086, 196.7847, 196.9440, 197.2406,
         196.9212, 196.7180, 196.2753],
        [196.5044, 196.8719, 196.6850, 196.8524, 196.7289, 196.8876, 197.1844,
         196.8657, 196.6629, 196.2206],
        [196.3182, 196.6862, 196.4984, 196.6665, 196.5424, 196.7024, 196.9997,
         196.6804, 196.4775, 196.0351],
        [196.4237, 196

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.2000, 196.5400, 196.8225, 198.0704, 196.8838, 197.0462, 196.8954,
         196.9896, 197.2270, 197.0246],
        [197.0095, 196.3448, 196.6299, 197.8780, 196.6925, 196.8545, 196.7022,
         196.7966, 197.0331, 196.8289],
        [196.8542, 196.1920, 196.4751, 197.7224, 196.5366, 196.6996, 196.5470,
         196.6418, 196.8791, 196.6756],
        [197.0788, 196.4162, 196.7000, 197.9461, 196.7617, 196.9253, 196.7714,
         196.8665, 197.1026, 196.8989],
        [197.4103, 196.7474, 197.0322, 198.2793, 197.0945, 197.2569, 197.1039,
         197.1992, 197.4343, 197.2309],
        [196.9619, 196.3000, 196.5833, 197.8297, 196.6447, 196.8078, 196.6541,
         196.7501, 196.9864, 196.7829],
        [197.0704, 196.4096, 196.6923, 197.9405, 196.7538, 196.9160, 196.7649,
         196.8594, 197.0969, 196.8941],
        [197.0775, 196

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.1362, 196.9876, 196.5862, 197.2967, 197.7660, 197.4337, 197.1268,
         197.6244, 196.7684, 198.1175],
        [196.8709, 196.7223, 196.3228, 197.0314, 197.4993, 197.1688, 196.8624,
         197.3595, 196.5029, 197.8525],
        [196.8652, 196.7162, 196.3147, 197.0265, 197.4951, 197.1638, 196.8564,
         197.3542, 196.4966, 197.8475],
        [197.2760, 197.1279, 196.7263, 197.4358, 197.9055, 197.5739, 197.2669,
         197.7636, 196.9086, 198.2572],
        [196.9247, 196.7745, 196.3736, 197.0838, 197.5546, 197.2206, 196.9169,
         197.4111, 196.5569, 197.9051],
        [197.0379, 196.8897, 196.4877, 197.1992, 197.6678, 197.3371, 197.0286,
         197.5270, 196.6696, 198.0203],
        [196.9101, 196.7600, 196.3590, 197.0698, 197.5402, 197.2065, 196.9020,
         197.3971, 196.5421, 197.8909],
        [197.0282, 196

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[196.2797, 196.9350, 196.4410, 195.9991, 196.9632, 196.5058, 196.9846,
         197.0035, 196.9585, 196.7250],
        [196.3490, 197.0056, 196.5100, 196.0683, 197.0315, 196.5758, 197.0544,
         197.0719, 197.0277, 196.7937],
        [196.4373, 197.0916, 196.5984, 196.1561, 197.1209, 196.6638, 197.1416,
         197.1603, 197.1154, 196.8825],
        [196.4607, 197.1163, 196.6211, 196.1783, 197.1425, 196.6879, 197.1653,
         197.1827, 197.1384, 196.9044],
        [196.1949, 196.8501, 196.3558, 195.9143, 196.8773, 196.4204, 196.8996,
         196.9179, 196.8731, 196.6391],
        [196.4349, 197.0921, 196.5962, 196.1544, 197.1181, 196.6626, 197.1406,
         197.1582, 197.1142, 196.8805],
        [196.6010, 197.2552, 196.7605, 196.3174, 197.2810, 196.8277, 197.3045,
         197.3207, 197.2767, 197.0429],
        [196.5544, 197.2075, 196.7141, 196.2710, 197.23

tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[196.9710, 196.7371, 197.0701, 196.9404, 196.7807, 196.8356, 197.0946,
         197.4291, 196.9313, 197.0779],
        [197.2394, 197.0069, 197.3392, 197.2107, 197.0498, 197.1040, 197.3644,
         197.6976, 197.2017, 197.3468],
        [196.9294, 196.6941, 197.0282, 196.8982, 196.7386, 196.7930, 197.0529,
         197.3874, 196.8887, 197.0363],
        [197.1089, 196.8748, 197.2083, 197.0792, 196.9185, 196.9725, 197.2335,
         197.5671, 197.0698, 197.2162],
        [197.0296, 196.7955, 197.1287, 196.9997, 196.8393, 196.8934, 197.1538,
         197.4870, 196.9910, 197.1369],
        [196.9371, 196.7036, 197.0368, 196.9080, 196.7466, 196.8002, 197.0619,
         197.3945, 196.8995, 197.0447],
        [197.1581, 196.9223, 197.2569, 197.1277, 196.9672, 197.0209, 197.2824,
         197.6160, 197.1179, 197.2654],
        [197.0005, 196.7680, 197.1001, 196.9713, 196.81

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[196.3437, 196.5705, 196.8302, 197.0740, 196.8123, 196.8907, 197.0634,
         197.1727, 196.7575, 197.2331],
        [196.3409, 196.5692, 196.8284, 197.0717, 196.8104, 196.8877, 197.0615,
         197.1709, 196.7547, 197.2311],
        [196.2844, 196.5107, 196.7704, 197.0148, 196.7526, 196.8314, 197.0041,
         197.1129, 196.6982, 197.1736],
        [196.2721, 196.4957, 196.7594, 197.0026, 196.7406, 196.8202, 196.9928,
         197.0995, 196.6848, 197.1635],
        [196.1717, 196.3973, 196.6593, 196.9030, 196.6408, 196.7194, 196.8931,
         197.0006, 196.5850, 197.0628],
        [196.2601, 196.4866, 196.7486, 196.9915, 196.7296, 196.8077, 196.9821,
         197.0894, 196.6723, 197.1523],
        [196.5355, 196.7591, 197.0199, 197.2640, 197.0019, 197.0827, 197.2528,
         197.3615, 196.9482, 197.4240],
        [195.9917, 196

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.4370, 196.9665, 196.7103, 196.9387, 196.5349, 196.2124, 196.8488,
         197.7872, 196.8969, 196.7391],
        [197.7878, 197.3184, 197.0615, 197.2908, 196.8890, 196.5670, 197.2013,
         198.1391, 197.2487, 197.0927],
        [197.7448, 197.2757, 197.0186, 197.2476, 196.8463, 196.5246, 197.1586,
         198.0961, 197.2059, 197.0503],
        [197.2926, 196.8238, 196.5670, 196.7955, 196.3916, 196.0689, 196.7054,
         197.6439, 196.7541, 196.5963],
        [197.5186, 197.0485, 196.7929, 197.0207, 196.6188, 196.2965, 196.9317,
         197.8690, 196.9789, 196.8226],
        [197.5361, 197.0661, 196.8104, 197.0384, 196.6363, 196.3139, 196.9492,
         197.8867, 196.9965, 196.8402],
        [197.5005, 197.0328, 196.7750, 197.0046, 196.6014, 196.2791, 196.9144,
         197.8528, 196.9632, 196.8061],
        [197.4467, 196

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[196.7591, 196.6187, 196.4667, 196.7538, 196.4199, 196.4273, 196.8621,
         196.7812, 196.3627, 197.2559],
        [196.8879, 196.7486, 196.5957, 196.8829, 196.5494, 196.5568, 196.9913,
         196.9104, 196.4924, 197.3858],
        [197.0357, 196.8946, 196.7432, 197.0310, 196.6946, 196.7032, 197.1384,
         197.0575, 196.6390, 197.5308],
        [196.9355, 196.7970, 196.6435, 196.9308, 196.5978, 196.6049, 197.0394,
         196.9585, 196.5406, 197.4341],
        [197.0805, 196.9419, 196.7886, 197.0760, 196.7421, 196.7499, 197.1844,
         197.1035, 196.6856, 197.5788],
        [197.0958, 196.9569, 196.8038, 197.0913, 196.7569, 196.7648, 197.1995,
         197.1185, 196.7007, 197.5937],
        [196.6534, 196.5123, 196.3606, 196.6481, 196.3133, 196.3203, 196.7559,
         196.6747, 196.2564, 197.1488],
        [197.1625, 197

tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[196.9093, 196.9123, 196.8108, 196.5861, 197.2680, 196.4911, 196.5146,
         197.6043, 196.8243, 196.8920],
        [196.7342, 196.7365, 196.6350, 196.4109, 197.0920, 196.3155, 196.3402,
         197.4286, 196.6491, 196.7189],
        [196.9513, 196.9550, 196.8525, 196.6276, 197.3098, 196.5335, 196.5572,
         197.6467, 196.8657, 196.9350],
        [196.8821, 196.8853, 196.7840, 196.5584, 197.2408, 196.4643, 196.4874,
         197.5775, 196.7974, 196.8643],
        [196.9812, 196.9851, 196.8820, 196.6566, 197.3386, 196.5643, 196.5888,
         197.6768, 196.8963, 196.9660],
        [197.0315, 197.0336, 196.9318, 196.7083, 197.3887, 196.6131, 196.6388,
         197.7255, 196.9472, 197.0165],
        [196.6646, 196.6674, 196.5662, 196.3407, 197.0226, 196.2467, 196.2704,
         197.3599, 196.5802, 196.6482],
        [197.1272, 197.1308, 197.0291, 196.8032, 197.48

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[196.6790, 196.0312, 196.4935, 196.2682, 197.0948, 196.6530, 196.9909,
         196.6259, 196.5895, 196.7942],
        [196.6596, 196.0095, 196.4724, 196.2485, 197.0738, 196.6312, 196.9697,
         196.6052, 196.5697, 196.7745],
        [196.6362, 195.9879, 196.4511, 196.2228, 197.0515, 196.6095, 196.9458,
         196.5813, 196.5448, 196.7483],
        [196.6390, 195.9892, 196.4525, 196.2256, 197.0532, 196.6106, 196.9477,
         196.5834, 196.5476, 196.7512],
        [196.2881, 195.6395, 196.1005, 195.8738, 196.7032, 196.2608, 196.5985,
         196.2327, 196.1970, 196.4011],
        [196.7164, 196.0679, 196.5313, 196.3036, 197.1316, 196.6895, 197.0261,
         196.6618, 196.6254, 196.8290],
        [196.5699, 195.9200, 196.3821, 196.1577, 196.9842, 196.5411, 196.8798,
         196.5152, 196.4799, 196.6842],
        [196.6217, 195.9711, 196.4334, 196.2097, 197.0354, 196.5921, 196.9311,
         196.5667, 196.5317, 196

tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[196.6455, 196.6457, 197.5698, 196.8717, 196.8132, 196.7604, 196.7861,
         196.1240, 196.8177, 196.3083],
        [196.7177, 196.7144, 197.6376, 196.9397, 196.8776, 196.8292, 196.8530,
         196.1906, 196.8858, 196.3763],
        [196.6760, 196.6745, 197.5985, 196.9007, 196.8403, 196.7883, 196.8164,
         196.1512, 196.8485, 196.3370],
        [196.6404, 196.6360, 197.5606, 196.8637, 196.8017, 196.7519, 196.7778,
         196.1131, 196.8102, 196.2999],
        [196.5001, 196.4961, 197.4208, 196.7236, 196.6618, 196.6116, 196.6380,
         195.9734, 196.6704, 196.1600],
        [196.4907, 196.4873, 197.4114, 196.7135, 196.6515, 196.6013, 196.6289,
         195.9639, 196.6615, 196.1502],
        [196.4015, 196.3971, 197.3219, 196.6245, 196.5625, 196.5128, 196.5385,
         195.8747, 196.5711, 196.0613],
        [196.3440, 196.3430, 197.2677, 196.5696, 196.51

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.6372, 196.7947, 196.0327, 196.8656, 196.6386, 196.7785, 196.1871,
         196.0102, 197.0346, 196.7246],
        [197.6530, 196.8056, 196.0452, 196.8797, 196.6533, 196.7927, 196.2036,
         196.0279, 197.0456, 196.7398],
        [197.3950, 196.5517, 195.7896, 196.6235, 196.3958, 196.5352, 195.9440,
         195.7674, 196.7907, 196.4819],
        [197.5856, 196.7428, 195.9813, 196.8144, 196.5866, 196.7271, 196.1360,
         195.9590, 196.9828, 196.6740],
        [197.5663, 196.7200, 195.9589, 196.7927, 196.5667, 196.7063, 196.1171,
         195.9393, 196.9595, 196.6538],
        [197.6186, 196.7761, 196.0146, 196.8475, 196.6198, 196.7602, 196.1688,
         195.9923, 197.0162, 196.7066],
        [197.4957, 196.6531, 195.8911, 196.7245, 196.4969, 196.6365, 196.0452,
         195.8685, 196.8925, 196.5831],
        [197.6190, 196

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[196.9611, 196.6999, 196.5424, 196.8502, 196.7256, 196.7660, 196.5299,
         196.6279, 196.4356, 196.9585],
        [197.1642, 196.9043, 196.7471, 197.0542, 196.9287, 196.9702, 196.7343,
         196.8309, 196.6392, 197.1615],
        [197.0371, 196.7798, 196.6214, 196.9276, 196.8017, 196.8447, 196.6060,
         196.7051, 196.5144, 197.0356],
        [197.1165, 196.8589, 196.7000, 197.0065, 196.8810, 196.9235, 196.6861,
         196.7837, 196.5927, 197.1143],
        [197.0396, 196.7801, 196.6217, 196.9298, 196.8039, 196.8455, 196.6098,
         196.7058, 196.5144, 197.0368],
        [196.8923, 196.6329, 196.4753, 196.7837, 196.6563, 196.6987, 196.4628,
         196.5586, 196.3681, 196.8897],
        [196.8317, 196.5741, 196.4149, 196.7232, 196.5957, 196.6388, 196.4019,
         196.4983, 196.3083, 196.8296],
        [196.8381, 196

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.4738, 196.8645, 196.5292, 196.7162, 196.7247, 196.0097, 196.5243,
         197.3598, 195.9818, 196.4616],
        [197.1590, 196.5500, 196.2139, 196.4019, 196.4093, 195.6924, 196.2099,
         197.0447, 195.6664, 196.1461],
        [197.6455, 197.0362, 196.7004, 196.8881, 196.8955, 196.1817, 196.6952,
         197.5315, 196.1537, 196.6337],
        [197.2962, 196.6868, 196.3503, 196.5377, 196.5466, 195.8309, 196.3461,
         197.1806, 195.8038, 196.2850],
        [197.2881, 196.6787, 196.3437, 196.5311, 196.5392, 195.8229, 196.3393,
         197.1746, 195.7956, 196.2747],
        [197.3666, 196.7566, 196.4204, 196.6093, 196.6165, 195.9021, 196.4169,
         197.2518, 195.8738, 196.3551],
        [197.4809, 196.8712, 196.5365, 196.7236, 196.7323, 196.0173, 196.5319,
         197.3674, 195.9886, 196.4683],
        [197.3025, 196

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[196.4849, 196.6979, 196.2789, 196.5688, 197.2717, 196.6423, 196.0747,
         196.6521, 196.6207, 196.8209],
        [196.4174, 196.6292, 196.2107, 196.5008, 197.2034, 196.5738, 196.0056,
         196.5819, 196.5526, 196.7517],
        [196.4426, 196.6542, 196.2358, 196.5259, 197.2283, 196.5988, 196.0306,
         196.6067, 196.5776, 196.7766],
        [196.3580, 196.5656, 196.1482, 196.4382, 197.1387, 196.5129, 195.9435,
         196.5168, 196.4900, 196.6886],
        [196.3674, 196.5757, 196.1584, 196.4487, 197.1485, 196.5225, 195.9534,
         196.5280, 196.4995, 196.6993],
        [196.6802, 196.8909, 196.4731, 196.7634, 197.4643, 196.8367, 196.2696,
         196.8448, 196.8153, 197.0146],
        [196.4449, 196.6541, 196.2368, 196.5274, 197.2270, 196.6006, 196.0321,
         196.6075, 196.5778, 196.7781],
        [196.3396, 196

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.8135, 195.9908, 197.0222, 196.6627, 196.6722, 195.9453, 196.6657,
         196.6943, 196.1491, 196.6624],
        [195.7560, 195.9344, 196.9664, 196.6076, 196.6165, 195.8895, 196.6096,
         196.6380, 196.0935, 196.6066],
        [195.8527, 196.0317, 197.0616, 196.7049, 196.7118, 195.9841, 196.7055,
         196.7333, 196.1899, 196.7023],
        [195.7831, 195.9607, 196.9911, 196.6323, 196.6408, 195.9124, 196.6352,
         196.6635, 196.1184, 196.6319],
        [195.7744, 195.9541, 196.9872, 196.6268, 196.6363, 195.9088, 196.6295,
         196.6589, 196.1129, 196.6272],
        [195.8499, 196.0271, 197.0570, 196.6992, 196.7075, 195.9802, 196.7013,
         196.7292, 196.1851, 196.6977],
        [195.7789, 195.9582, 196.9909, 196.6305, 196.6400, 195.9124, 196.6335,
         196.6626, 196.1167, 196.6308],
        [196.1015, 196

       grad_fn=<CdistBackward0>)
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[196.1544, 195.9491, 196.4445, 196.6538, 196.6749, 195.7143, 196.5724,
         196.5318, 196.0232, 195.9962],
        [196.1554, 195.9505, 196.4459, 196.6556, 196.6739, 195.7158, 196.5723,
         196.5347, 196.0246, 195.9974],
        [196.0668, 195.8620, 196.3567, 196.5652, 196.5855, 195.6261, 196.4838,
         196.4449, 195.9359, 195.9076],
        [196.1901, 195.9869, 196.4794, 196.6873, 196.7092, 195.7489, 196.6084,
         196.5665, 196.0596, 196.0313],
        [195.9119, 195.7058, 196.2020, 196.4130, 196.4314, 195.4726, 196.3300,
         196.2907, 195.7817, 195.7532],
        [196.1765, 195.9732, 196.4660, 196.6749, 196.6959, 195.7360, 196.5953,
         196.5533, 196.0463, 196.0184],
        [196.1293, 195.9241, 196.4194, 196.6289, 196.6499, 195.6893, 196.5475,
         196.5068, 195.9983, 195.9711],
        [196.0409, 195

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.7629, 196.5012, 196.5577, 196.2484, 197.1170, 196.7294, 196.7444,
         196.6395, 196.3312, 196.4906],
        [195.6518, 196.3944, 196.4473, 196.1400, 197.0047, 196.6229, 196.6350,
         196.5288, 196.2197, 196.3775],
        [195.7995, 196.5414, 196.5957, 196.2873, 197.1518, 196.7687, 196.7826,
         196.6770, 196.3676, 196.5276],
        [195.4534, 196.1945, 196.2494, 195.9406, 196.8076, 196.4233, 196.4364,
         196.3312, 196.0211, 196.1813],
        [195.6875, 196.4299, 196.4833, 196.1755, 197.0398, 196.6574, 196.6707,
         196.5654, 196.2554, 196.4148],
        [195.7508, 196.4920, 196.5447, 196.2380, 197.1029, 196.7196, 196.7330,
         196.6282, 196.3190, 196.4760],
        [195.4168, 196.1584, 196.2121, 195.9043, 196.7709, 196.3876, 196.3998,
         196.2943, 195.9845, 196.1429],
        [195.6995, 196.4428, 196.4952, 196.1882, 197.05

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.6774, 196.3070, 195.9596, 196.1382, 196.4619, 196.4817, 196.3245,
         196.4069, 196.2156, 196.2987],
        [195.7511, 196.3785, 196.0328, 196.2123, 196.5371, 196.5523, 196.3967,
         196.4807, 196.2880, 196.3732],
        [195.5644, 196.1923, 195.8459, 196.0249, 196.3489, 196.3693, 196.2135,
         196.2941, 196.1032, 196.1852],
        [195.6561, 196.2831, 195.9372, 196.1173, 196.4426, 196.4569, 196.3015,
         196.3854, 196.1925, 196.2781],
        [195.4518, 196.0785, 195.7325, 195.9120, 196.2368, 196.2555, 196.1003,
         196.1813, 195.9901, 196.0728],
        [195.8929, 196.5217, 196.1756, 196.3545, 196.6782, 196.6956, 196.5393,
         196.6228, 196.4306, 196.5149],
        [195.6993, 196.3284, 195.9812, 196.1596, 196.4842, 196.5013, 196.3441,
         196.4283, 196.2368, 196.3213],
        [195.5676, 196.1942, 195.8489, 196.0282, 196.35

tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.7042, 196.1652, 195.7502, 197.0759, 196.1849, 196.1300, 196.4096,
         196.5891, 196.2934, 197.1928],
        [195.7617, 196.2226, 195.8068, 197.1374, 196.2406, 196.1881, 196.4649,
         196.6463, 196.3511, 197.2509],
        [195.4485, 195.9096, 195.4947, 196.8236, 195.9262, 195.8729, 196.1508,
         196.3333, 196.0381, 196.9376],
        [195.6077, 196.0686, 195.6530, 196.9837, 196.0855, 196.0330, 196.3098,
         196.4922, 196.1971, 197.0968],
        [195.6897, 196.1508, 195.7348, 197.0650, 196.1691, 196.1161, 196.3932,
         196.5746, 196.2791, 197.1789],
        [195.7045, 196.1652, 195.7514, 197.0806, 196.1828, 196.1307, 196.4081,
         196.5897, 196.2945, 197.1937],
        [195.6249, 196.0856, 195.6713, 197.0011, 196.1028, 196.0506, 196.3280,
         196.5099, 196.2147, 197.1140],
        [195.6360, 196.0970, 195.6823, 197.0090, 196.11

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[196.4695, 196.4393, 195.9249, 196.6494, 196.5100, 196.1559, 197.4128,
         196.7040, 196.6327, 196.4934],
        [196.2079, 196.1792, 195.6650, 196.3910, 196.2505, 195.8970, 197.1531,
         196.4457, 196.3708, 196.2344],
        [196.2199, 196.1925, 195.6781, 196.4049, 196.2631, 195.9100, 197.1661,
         196.4588, 196.3828, 196.2487],
        [196.4625, 196.4327, 195.9182, 196.6432, 196.5036, 196.1495, 197.4063,
         196.6977, 196.6257, 196.4869],
        [196.1685, 196.1374, 195.6229, 196.3484, 196.2085, 195.8555, 197.1116,
         196.4030, 196.3326, 196.1925],
        [196.3050, 196.2747, 195.7604, 196.4850, 196.3463, 195.9921, 197.2486,
         196.5406, 196.4679, 196.3279],
        [196.2984, 196.2674, 195.7532, 196.4779, 196.3382, 195.9848, 197.2413,
         196.5325, 196.4621, 196.3222],
        [196.2003, 196

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.2053, 195.9228, 196.4493, 196.4264, 197.0919, 196.5924, 196.3705,
         196.0802, 195.9195, 196.2857],
        [197.0421, 195.7582, 196.2853, 196.2608, 196.9263, 196.4297, 196.2083,
         195.9173, 195.7574, 196.1220],
        [197.0277, 195.7442, 196.2708, 196.2481, 196.9127, 196.4149, 196.1940,
         195.9029, 195.7430, 196.1074],
        [196.9289, 195.6449, 196.1716, 196.1488, 196.8132, 196.3164, 196.0958,
         195.8042, 195.6442, 196.0085],
        [197.0341, 195.7512, 196.2784, 196.2561, 196.9213, 196.4217, 196.1996,
         195.9090, 195.7480, 196.1138],
        [197.3230, 196.0402, 196.5661, 196.5417, 197.2075, 196.7098, 196.4884,
         196.1982, 196.0386, 196.4039],
        [197.3425, 196.0600, 196.5865, 196.5621, 197.2283, 196.7294, 196.5071,
         196.2175, 196.0572, 196.4231],
        [197.1332, 195.8499, 196.3769, 196.3521, 197.01

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.5571, 197.0075, 197.3579, 196.3721, 196.3403, 196.1734, 196.5546,
         196.3383, 196.5412, 196.5126],
        [195.4379, 196.8876, 197.2366, 196.2496, 196.2227, 196.0541, 196.4325,
         196.2144, 196.4185, 196.3913],
        [195.4478, 196.8970, 197.2465, 196.2616, 196.2324, 196.0640, 196.4430,
         196.2263, 196.4289, 196.4018],
        [195.3961, 196.8464, 197.1948, 196.2077, 196.1810, 196.0127, 196.3909,
         196.1723, 196.3772, 196.3498],
        [195.3316, 196.7795, 197.1306, 196.1467, 196.1154, 195.9471, 196.3286,
         196.1110, 196.3128, 196.2858],
        [195.6158, 197.0658, 197.4146, 196.4276, 196.4008, 196.2319, 196.6098,
         196.3933, 196.5965, 196.5692],
        [195.6497, 197.0987, 197.4498, 196.4648, 196.4332, 196.2651, 196.6462,
         196.4313, 196.6320, 196.6043],
        [195.4544, 196.9050, 197.2545, 196.2671, 196.23

tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[197.3022, 196.2213, 196.5931, 196.3191, 196.6366, 196.6155, 196.4406,
         196.5288, 196.1606, 197.2214],
        [196.8399, 195.7585, 196.1327, 195.8563, 196.1739, 196.1529, 195.9788,
         196.0663, 195.6954, 196.7586],
        [197.1863, 196.1048, 196.4775, 196.2037, 196.5209, 196.4996, 196.3244,
         196.4131, 196.0444, 197.1056],
        [197.0927, 196.0136, 196.3881, 196.1107, 196.4283, 196.4062, 196.2350,
         196.3203, 195.9506, 197.0127],
        [197.3112, 196.2306, 196.6032, 196.3295, 196.6469, 196.6244, 196.4508,
         196.5390, 196.1706, 197.2311],
        [196.9087, 195.8284, 196.2035, 195.9260, 196.2433, 196.2221, 196.0495,
         196.1353, 195.7647, 196.8279],
        [196.9471, 195.8665, 196.2403, 195.9635, 196.2812, 196.2602, 196.0868,
         196.1737, 195.8033, 196.8661],
        [197.1844, 196.1049, 196.4779, 196.2020, 196.51

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.9751, 196.4334, 196.4310, 196.4284, 196.0368, 196.4259, 196.2017,
         196.5295, 196.1510, 196.2985],
        [195.9057, 196.3647, 196.3626, 196.3588, 195.9670, 196.3566, 196.1324,
         196.4604, 196.0817, 196.2284],
        [195.7335, 196.1932, 196.1908, 196.1880, 195.7972, 196.1854, 195.9619,
         196.2900, 195.9097, 196.0584],
        [195.9096, 196.3685, 196.3662, 196.3628, 195.9712, 196.3602, 196.1364,
         196.4644, 196.0856, 196.2326],
        [195.6641, 196.1243, 196.1216, 196.1184, 195.7273, 196.1159, 195.8922,
         196.2206, 195.8402, 195.9883],
        [196.1683, 196.6253, 196.6226, 196.6190, 196.2287, 196.6159, 196.3942,
         196.7216, 196.3433, 196.4900],
        [195.8885, 196.3472, 196.3447, 196.3410, 195.9513, 196.3378, 196.1165,
         196.4442, 196.0641, 196.2119],
        [195.6269, 196

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[196.2486, 195.6727, 196.3457, 196.1709, 195.6700, 196.0112, 195.8504,
         196.3048, 195.4695, 196.0942],
        [196.1826, 195.6059, 196.2793, 196.1058, 195.6056, 195.9435, 195.7846,
         196.2379, 195.4026, 196.0279],
        [196.2065, 195.6299, 196.3034, 196.1291, 195.6280, 195.9679, 195.8079,
         196.2624, 195.4265, 196.0516],
        [196.1917, 195.6160, 196.2894, 196.1132, 195.6127, 195.9550, 195.7946,
         196.2479, 195.4118, 196.0370],
        [196.1166, 195.5402, 196.2139, 196.0380, 195.5376, 195.8794, 195.7184,
         196.1727, 195.3368, 195.9621],
        [196.5160, 195.9423, 196.6138, 196.4395, 195.9390, 196.2802, 196.1208,
         196.5721, 195.7376, 196.3621],
        [196.2711, 195.6973, 196.3695, 196.1931, 195.6948, 196.0366, 195.8770,
         196.3270, 195.4925, 196.1177],
        [196.3330, 195

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[196.7354, 196.0129, 196.1356, 196.2902, 195.4311, 196.3699, 195.7504,
         195.9711, 196.1772, 195.8474],
        [196.6441, 195.9202, 196.0426, 196.1995, 195.3396, 196.2774, 195.6573,
         195.8793, 196.0844, 195.7544],
        [196.9265, 196.2028, 196.3257, 196.4822, 195.6224, 196.5601, 195.9401,
         196.1623, 196.3674, 196.0369],
        [197.0659, 196.3427, 196.4655, 196.6211, 195.7619, 196.6995, 196.0802,
         196.3019, 196.5073, 196.1771],
        [196.5570, 195.8331, 195.9553, 196.1121, 195.2524, 196.1900, 195.5698,
         195.7917, 195.9967, 195.6672],
        [196.8466, 196.1243, 196.2478, 196.4023, 195.5428, 196.4820, 195.8616,
         196.0829, 196.2884, 195.9584],
        [196.8155, 196.0912, 196.2128, 196.3698, 195.5108, 196.4471, 195.8281,
         196.0499, 196.2553, 195.9258],
        [196.7560, 196

tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[196.3929, 196.0316, 195.7135, 196.4046, 196.3213, 196.2698, 196.6474,
         196.2117, 197.3123, 196.2570],
        [196.1556, 195.7949, 195.4774, 196.1685, 196.0844, 196.0323, 196.4106,
         195.9748, 197.0753, 196.0205],
        [195.8972, 195.5366, 195.2190, 195.9107, 195.8261, 195.7732, 196.1529,
         195.7175, 196.8177, 195.7622],
        [196.0144, 195.6539, 195.3354, 196.0271, 195.9441, 195.8908, 196.2706,
         195.8357, 196.9356, 195.8792],
        [196.0695, 195.7071, 195.3875, 196.0818, 195.9963, 195.9433, 196.3228,
         195.8854, 196.9884, 195.9314],
        [196.0062, 195.6438, 195.3252, 196.0189, 195.9326, 195.8802, 196.2594,
         195.8221, 196.9246, 195.8686],
        [195.8077, 195.4467, 195.1288, 195.8213, 195.7359, 195.6828, 196.0629,
         195.6270, 196.7279, 195.6721],
        [196.3992, 196.0383, 195.7199, 196.4109, 196.32

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.8454, 195.9490, 196.1220, 196.1283, 195.8110, 195.6089, 196.3716,
         195.7588, 195.7059, 195.7326],
        [195.9155, 196.0171, 196.1901, 196.1976, 195.8803, 195.6779, 196.4419,
         195.8280, 195.7747, 195.8026],
        [195.9313, 196.0323, 196.2070, 196.2133, 195.8965, 195.6943, 196.4600,
         195.8442, 195.7909, 195.8200],
        [195.9261, 196.0295, 196.2013, 196.2086, 195.8911, 195.6887, 196.4507,
         195.8390, 195.7860, 195.8121],
        [195.8594, 195.9609, 196.1357, 196.1416, 195.8248, 195.6225, 196.3876,
         195.7722, 195.7196, 195.7477],
        [195.9975, 196.0999, 196.2731, 196.2793, 195.9627, 195.7600, 196.5236,
         195.9102, 195.8576, 195.8843],
        [195.7063, 195.8095, 195.9833, 195.9897, 195.6721, 195.4701, 196.2333,
         195.6196, 195.5672, 195.5941],
        [195.9813, 196

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.9792, 196.8965, 195.8714, 196.2926, 196.4561, 196.1927, 196.1435,
         196.2895, 195.9634, 195.9663],
        [196.0097, 196.9252, 195.9006, 196.3210, 196.4832, 196.2203, 196.1706,
         196.3176, 195.9913, 195.9935],
        [195.9930, 196.9079, 195.8831, 196.3032, 196.4652, 196.2029, 196.1534,
         196.3004, 195.9743, 195.9763],
        [195.6950, 196.6108, 195.5863, 196.0070, 196.1690, 195.9070, 195.8568,
         196.0038, 195.6767, 195.6797],
        [195.9521, 196.8684, 195.8440, 196.2651, 196.4280, 196.1645, 196.1142,
         196.2610, 195.9343, 195.9371],
        [195.9121, 196.8284, 195.8042, 196.2254, 196.3879, 196.1243, 196.0736,
         196.2208, 195.8936, 195.8963],
        [195.8339, 196.7476, 195.7233, 196.1433, 196.3047, 196.0432, 195.9921,
         196.1400, 195.8129, 195.8152],
        [195.8702, 196

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.9927, 196.7923, 195.9898, 197.0667, 195.9611, 195.8489, 196.4048,
         196.0503, 196.7719, 195.4686],
        [195.7813, 196.5862, 195.7794, 196.8573, 195.7545, 195.6378, 196.1968,
         195.8401, 196.5611, 195.2584],
        [195.7431, 196.5470, 195.7410, 196.8186, 195.7154, 195.5990, 196.1580,
         195.8018, 196.5232, 195.2202],
        [195.8660, 196.6674, 195.8632, 196.9393, 195.8354, 195.7213, 196.2786,
         195.9237, 196.6464, 195.3430],
        [195.9597, 196.7641, 195.9573, 197.0353, 195.9321, 195.8166, 196.3747,
         196.0179, 196.7386, 195.4362],
        [195.8785, 196.6829, 195.8760, 196.9522, 195.8503, 195.7353, 196.2931,
         195.9366, 196.6594, 195.3566],
        [195.9140, 196.7184, 195.9118, 196.9898, 195.8866, 195.7708, 196.3291,
         195.9724, 196.6931, 195.3906],
        [195.7095, 196

tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[196.5457, 195.9274, 195.8522, 196.0791, 195.9888, 196.0331, 195.6231,
         195.9071, 196.9175, 195.3565],
        [196.5187, 195.9002, 195.8244, 196.0517, 195.9614, 196.0059, 195.5955,
         195.8792, 196.8902, 195.3290],
        [196.5359, 195.9162, 195.8417, 196.0691, 195.9788, 196.0228, 195.6127,
         195.8956, 196.9069, 195.3466],
        [196.3828, 195.7649, 195.6874, 195.9155, 195.8243, 195.8703, 195.4579,
         195.7428, 196.7545, 195.1908],
        [196.5916, 195.9728, 195.8948, 196.1239, 196.0327, 196.0786, 195.6671,
         195.9486, 196.9627, 195.4004],
        [196.6849, 196.0672, 195.9889, 196.2174, 196.1262, 196.1721, 195.7611,
         196.0431, 197.0566, 195.4939],
        [196.6233, 196.0062, 195.9269, 196.1555, 196.0644, 196.1107, 195.6993,
         195.9817, 196.9951, 195.4321],
        [196.4206, 195.8012, 195.7256, 195.9535, 195.86

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[196.0136, 195.6220, 195.8037, 195.3650, 195.9520, 195.8147, 196.0400,
         195.3197, 195.5461, 196.7257],
        [196.1563, 195.7654, 195.9460, 195.5095, 196.0953, 195.9593, 196.1824,
         195.4632, 195.6910, 196.8682],
        [196.0422, 195.6514, 195.8317, 195.3959, 195.9814, 195.8447, 196.0687,
         195.3495, 195.5763, 196.7544],
        [195.9438, 195.5542, 195.7357, 195.2964, 195.8835, 195.7465, 195.9713,
         195.2503, 195.4768, 196.6579],
        [196.0667, 195.6767, 195.8582, 195.4194, 196.0065, 195.8688, 196.0937,
         195.3744, 195.5994, 196.7805],
        [196.0109, 195.6232, 195.8027, 195.3677, 195.9527, 195.8158, 196.0388,
         195.3206, 195.5462, 196.7263],
        [196.0333, 195.6437, 195.8249, 195.3866, 195.9735, 195.8357, 196.0606,
         195.3414, 195.5662, 196.7474],
        [196.1016, 195

tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[196.2551, 196.9311, 196.1929, 195.6720, 196.0863, 195.5951, 196.0931,
         195.8913, 195.3693, 196.0934],
        [196.1485, 196.8259, 196.0901, 195.5678, 195.9816, 195.4930, 195.9877,
         195.7874, 195.2658, 195.9897],
        [196.2712, 196.9485, 196.2122, 195.6899, 196.1035, 195.6156, 196.1102,
         195.9082, 195.3869, 196.1111],
        [196.2923, 196.9696, 196.2335, 195.7111, 196.1247, 195.6369, 196.1311,
         195.9300, 195.4089, 196.1326],
        [196.4278, 197.1039, 196.3667, 195.8461, 196.2590, 195.7701, 196.2662,
         196.0666, 195.5457, 196.2675],
        [196.2362, 196.9138, 196.1779, 195.6552, 196.0689, 195.5813, 196.0754,
         195.8735, 195.3522, 196.0765],
        [195.9962, 196.6733, 195.9364, 195.4144, 195.8292, 195.3382, 195.8351,
         195.6332, 195.1106, 195.8362],
        [196.1230, 196.8001, 196.0634, 195.5430, 195.95

tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.9204, 196.0344, 195.6962, 195.7105, 196.0514, 195.9471, 195.1151,
         195.7771, 195.7489, 195.6122],
        [195.9911, 196.1076, 195.7699, 195.7836, 196.1237, 196.0193, 195.1873,
         195.8506, 195.8214, 195.6859],
        [195.8373, 195.9521, 195.6142, 195.6281, 195.9690, 195.8647, 195.0326,
         195.6945, 195.6671, 195.5296],
        [195.9489, 196.0637, 195.7253, 195.7385, 196.0807, 195.9758, 195.1430,
         195.8071, 195.7783, 195.6413],
        [196.0415, 196.1573, 195.8182, 195.8319, 196.1736, 196.0684, 195.2357,
         195.9009, 195.8697, 195.7355],
        [195.9401, 196.0542, 195.7174, 195.7318, 196.0715, 195.9677, 195.1363,
         195.7965, 195.7708, 195.6322],
        [195.7498, 195.8661, 195.5273, 195.5415, 195.8819, 195.7777, 194.9456,
         195.6081, 195.5787, 195.4435],
        [195.8023, 195.9187, 195.5810, 195.5941, 195.93

tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[196.7473, 195.1113, 195.8364, 195.9817, 195.8441, 195.3107, 195.7866,
         196.4761, 195.1340, 196.1505],
        [196.7591, 195.1230, 195.8478, 195.9936, 195.8579, 195.3241, 195.8006,
         196.4864, 195.1461, 196.1620],
        [196.6586, 195.0203, 195.7457, 195.8909, 195.7551, 195.2223, 195.6968,
         196.3860, 195.0459, 196.0580],
        [196.9525, 195.3157, 196.0404, 196.1852, 196.0495, 195.5174, 195.9928,
         196.6801, 195.3415, 196.3539],
        [196.6232, 194.9850, 195.7124, 195.8564, 195.7182, 195.1855, 195.6601,
         196.3536, 195.0112, 196.0235],
        [196.6061, 194.9693, 195.6958, 195.8409, 195.7028, 195.1690, 195.6449,
         196.3356, 194.9926, 196.0087],
        [196.6501, 195.0121, 195.7361, 195.8817, 195.7458, 195.2134, 195.6871,
         196.3767, 195.0365, 196.0492],
        [196.6653, 195.0271, 195.7537, 195.8979, 195.76

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.4701, 195.1147, 195.6539, 196.8015, 195.6501, 195.5745, 195.9859,
         196.0660, 195.8200, 195.8134],
        [195.4342, 195.0792, 195.6203, 196.7668, 195.6153, 195.5385, 195.9515,
         196.0309, 195.7870, 195.7790],
        [195.5123, 195.1569, 195.6964, 196.8471, 195.6944, 195.6172, 196.0303,
         196.1096, 195.8646, 195.8593],
        [195.3817, 195.0264, 195.5676, 196.7135, 195.5624, 195.4861, 195.8984,
         195.9781, 195.7339, 195.7255],
        [195.3325, 194.9763, 195.5172, 196.6658, 195.5141, 195.4377, 195.8494,
         195.9293, 195.6844, 195.6772],
        [195.1568, 194.8004, 195.3450, 196.4910, 195.3398, 195.2624, 195.6752,
         195.7545, 195.5125, 195.5022],
        [195.6296, 195.2752, 195.8164, 196.9626, 195.8111, 195.7336, 196.1473,
         196.2259, 195.9828, 195.9746],
        [195.5044, 195

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.5576, 195.7313, 195.3640, 195.5974, 195.2735, 194.8949, 195.5887,
         195.7516, 195.7098, 195.3291],
        [195.5205, 195.6934, 195.3256, 195.5600, 195.2351, 194.8575, 195.5495,
         195.7139, 195.6733, 195.2893],
        [195.7436, 195.9194, 195.5521, 195.7847, 195.4623, 195.0837, 195.7758,
         195.9395, 195.8961, 195.5168],
        [195.8735, 196.0467, 195.6805, 195.9138, 195.5911, 195.2113, 195.9029,
         196.0681, 196.0252, 195.6445],
        [195.7593, 195.9333, 195.5659, 195.7997, 195.4763, 195.0981, 195.7884,
         195.9540, 195.9119, 195.5292],
        [195.5047, 195.6783, 195.3098, 195.5442, 195.2193, 194.8421, 195.5336,
         195.6981, 195.6579, 195.2733],
        [195.6234, 195.7944, 195.4270, 195.6623, 195.3371, 194.9583, 195.6498,
         195.8155, 195.7757, 195.3899],
        [195.9206, 196

tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.6119, 196.3911, 196.0245, 195.9468, 196.7617, 195.8391, 195.9189,
         196.6045, 195.0899, 195.6123],
        [195.4943, 196.2678, 195.9031, 195.8226, 196.6386, 195.7184, 195.7996,
         196.4832, 194.9666, 195.4903],
        [195.6038, 196.3792, 196.0152, 195.9348, 196.7508, 195.8295, 195.9106,
         196.5943, 195.0783, 195.6021],
        [195.8224, 196.5980, 196.2323, 196.1531, 196.9691, 196.0475, 196.1283,
         196.8126, 195.2970, 195.8214],
        [195.4831, 196.2572, 195.8929, 195.8127, 196.6281, 195.7081, 195.7884,
         196.4723, 194.9562, 195.4793],
        [195.6061, 196.3793, 196.0146, 195.9344, 196.7500, 195.8303, 195.9107,
         196.5944, 195.0782, 195.6018],
        [195.5042, 196.2805, 195.9165, 195.8365, 196.6522, 195.7307, 195.8112,
         196.4954, 194.9796, 195.5026],
        [195.7239, 196.4974, 196.1303, 196.0511, 196.86

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.8028, 195.5457, 195.0505, 195.5916, 195.7538, 196.3739, 195.5946,
         195.6462, 195.7104, 196.6458],
        [195.8476, 195.5925, 195.0962, 195.6374, 195.8001, 196.4215, 195.6403,
         195.6924, 195.7544, 196.6913],
        [195.5108, 195.2557, 194.7608, 195.3011, 195.4626, 196.0842, 195.3042,
         195.3569, 195.4174, 196.3540],
        [195.7112, 195.4559, 194.9598, 195.5008, 195.6632, 196.2846, 195.5037,
         195.5558, 195.6177, 196.5545],
        [195.8806, 195.6244, 195.1298, 195.6707, 195.8321, 196.4535, 195.6733,
         195.7254, 195.7887, 196.7231],
        [195.6699, 195.4151, 194.9187, 195.4596, 195.6223, 196.2438, 195.4626,
         195.5150, 195.5759, 196.5136],
        [195.8019, 195.5450, 195.0510, 195.5916, 195.7532, 196.3734, 195.5949,
         195.6470, 195.7104, 196.6452],
        [195.7100, 195

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.7485, 195.5739, 196.5888, 195.6616, 195.3328, 196.4319, 195.7898,
         196.1005, 195.2579, 195.7523],
        [195.7315, 195.5587, 196.5731, 195.6455, 195.3165, 196.4161, 195.7748,
         196.0856, 195.2404, 195.7365],
        [195.7255, 195.5520, 196.5646, 195.6387, 195.3096, 196.4087, 195.7666,
         196.0768, 195.2342, 195.7294],
        [195.7793, 195.6075, 196.6204, 195.6933, 195.3644, 196.4641, 195.8227,
         196.1332, 195.2867, 195.7841],
        [195.8687, 195.6950, 196.7083, 195.7819, 195.4529, 196.5517, 195.9091,
         196.2200, 195.3781, 195.8723],
        [195.9653, 195.7937, 196.8036, 195.8785, 195.5496, 196.6486, 196.0057,
         196.3161, 195.4720, 195.9686],
        [195.9282, 195.7566, 196.7676, 195.8418, 195.5129, 196.6120, 195.9694,
         196.2802, 195.4358, 195.9322],
        [195.8212, 195

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.8048, 195.8291, 196.5301, 195.1776, 195.8517, 195.5098, 195.7064,
         195.5391, 195.0158, 196.1998],
        [195.7796, 195.8022, 196.5057, 195.1519, 195.8265, 195.4835, 195.6837,
         195.5140, 194.9906, 196.1745],
        [195.8944, 195.9177, 196.6215, 195.2682, 195.9423, 195.5987, 195.7993,
         195.6305, 195.1057, 196.2892],
        [195.6534, 195.6772, 196.3798, 195.0257, 195.7005, 195.3580, 195.5580,
         195.3876, 194.8643, 196.0481],
        [195.5033, 195.5294, 196.2291, 194.8753, 195.5501, 195.2094, 195.4061,
         195.2366, 194.7139, 195.8979],
        [195.8229, 195.8501, 196.5490, 195.1976, 195.8703, 195.5304, 195.7256,
         195.5581, 195.0339, 196.2176],
        [195.6041, 195.6318, 196.3297, 194.9775, 195.6511, 195.3116, 195.5064,
         195.3379, 194.8148, 195.9986],
        [195.8549, 195.8788, 196.5820, 195.2287, 195.90

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[196.1545, 196.2655, 195.3741, 195.6869, 196.4239, 194.9427, 195.3375,
         195.4393, 195.3848, 195.4874],
        [196.3247, 196.4362, 195.5448, 195.8575, 196.5943, 195.1145, 195.5079,
         195.6097, 195.5561, 195.6591],
        [196.2841, 196.3976, 195.5062, 195.8186, 196.5554, 195.0767, 195.4684,
         195.5719, 195.5173, 195.6218],
        [196.4109, 196.5238, 195.6325, 195.9456, 196.6812, 195.2049, 195.5961,
         195.6967, 195.6442, 195.7492],
        [196.2358, 196.3490, 195.4579, 195.7696, 196.5059, 195.0270, 195.4201,
         195.5229, 195.4687, 195.5727],
        [196.2526, 196.3646, 195.4730, 195.7860, 196.5231, 195.0432, 195.4361,
         195.5384, 195.4841, 195.5875],
        [196.1554, 196.2697, 195.3785, 195.6907, 196.4267, 194.9491, 195.3411,
         195.4440, 195.3892, 195.4946],
        [196.4602, 196

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.6649, 196.1308, 196.0136, 195.4254, 195.4889, 196.1502, 196.5003,
         195.6321, 195.2458, 195.3702],
        [195.6574, 196.1215, 196.0043, 195.4170, 195.4807, 196.1438, 196.4921,
         195.6251, 195.2388, 195.3623],
        [195.7409, 196.2061, 196.0884, 195.5020, 195.5645, 196.2276, 196.5767,
         195.7082, 195.3233, 195.4463],
        [195.6763, 196.1401, 196.0236, 195.4357, 195.4993, 196.1636, 196.5109,
         195.6449, 195.2568, 195.3815],
        [195.6541, 196.1193, 196.0029, 195.4149, 195.4765, 196.1426, 196.4901,
         195.6233, 195.2346, 195.3600],
        [195.6365, 196.1010, 195.9843, 195.3973, 195.4586, 196.1254, 196.4724,
         195.6056, 195.2178, 195.3422],
        [195.6469, 196.1108, 195.9941, 195.4064, 195.4698, 196.1340, 196.4816,
         195.6154, 195.2277, 195.3519],
        [195.6474, 196

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.8833, 195.2069, 194.8872, 195.4417, 195.7160, 195.7605, 194.7729,
         195.2258, 195.4278, 195.6237],
        [195.8685, 195.1921, 194.8723, 195.4267, 195.7011, 195.7458, 194.7582,
         195.2106, 195.4126, 195.6087],
        [195.6096, 194.9341, 194.6142, 195.1679, 195.4438, 195.4885, 194.5011,
         194.9532, 195.1547, 195.3510],
        [195.7149, 195.0406, 194.7198, 195.2742, 195.5489, 195.5939, 194.6061,
         195.0591, 195.2604, 195.4572],
        [196.0555, 195.3809, 195.0598, 195.6156, 195.8882, 195.9323, 194.9459,
         195.4002, 195.6009, 195.7969],
        [195.9807, 195.3059, 194.9849, 195.5404, 195.8136, 195.8578, 194.8710,
         195.3250, 195.5260, 195.7221],
        [195.8891, 195.2129, 194.8932, 195.4480, 195.7222, 195.7662, 194.7785,
         195.2321, 195.4342, 195.6303],
        [195.8644, 195

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.3829, 195.2877, 195.4527, 195.1884, 195.0123, 195.2749, 194.8520,
         195.3771, 195.5277, 195.1991],
        [195.4652, 195.3679, 195.5327, 195.2682, 195.0936, 195.3558, 194.9345,
         195.4572, 195.6081, 195.2798],
        [195.5224, 195.4243, 195.5889, 195.3247, 195.1506, 195.4126, 194.9921,
         195.5133, 195.6647, 195.3363],
        [195.6021, 195.5068, 195.6713, 195.4070, 195.2319, 195.4942, 195.0721,
         195.5958, 195.7468, 195.4183],
        [195.5243, 195.4297, 195.5953, 195.3266, 195.1534, 195.4162, 194.9929,
         195.5168, 195.6701, 195.3386],
        [195.3543, 195.2577, 195.4232, 195.1567, 194.9823, 195.2449, 194.8223,
         195.3465, 195.4979, 195.1685],
        [195.3676, 195.2725, 195.4383, 195.1707, 194.9963, 195.2592, 194.8356,
         195.3605, 195.5127, 195.1821],
        [195.4195, 195

       grad_fn=<CdistBackward0>)
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[196.3900, 195.3841, 195.4805, 195.2475, 195.5525, 195.4006, 195.5521,
         195.4163, 195.4234, 194.6813],
        [196.2747, 195.2677, 195.3663, 195.1330, 195.4381, 195.2857, 195.4367,
         195.3003, 195.3073, 194.5665],
        [196.4757, 195.4701, 195.5659, 195.3345, 195.6387, 195.4874, 195.6379,
         195.5023, 195.5095, 194.7674],
        [196.5237, 195.5147, 195.6148, 195.3764, 195.6831, 195.5311, 195.6847,
         195.5463, 195.5550, 194.8139],
        [196.4143, 195.4088, 195.5046, 195.2688, 195.5756, 195.4221, 195.5758,
         195.4404, 195.4474, 194.7049],
        [196.7160, 195.7079, 195.8063, 195.5709, 195.8768, 195.7246, 195.8769,
         195.7400, 195.7482, 195.0071],
        [196.4124, 195.4029, 195.5040, 195.2675, 195.5734, 195.4215, 195.5736,
         195.4352, 195.4436, 194.7034],
        [196.3666, 195

       grad_fn=<CdistBackward0>)
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.3488, 194.6730, 195.5074, 195.3441, 195.3690, 195.3012, 194.4975,
         195.2568, 196.1038, 195.1219],
        [195.4348, 194.7579, 195.5937, 195.4294, 195.4541, 195.3868, 194.5822,
         195.3422, 196.1884, 195.2070],
        [195.6065, 194.9299, 195.7660, 195.6011, 195.6259, 195.5585, 194.7547,
         195.5142, 196.3600, 195.3784],
        [195.5246, 194.8496, 195.6831, 195.5205, 195.5456, 195.4767, 194.6762,
         195.4332, 196.2809, 195.2977],
        [195.7492, 195.0759, 195.9087, 195.7447, 195.7707, 195.7018, 194.9012,
         195.6589, 196.5045, 195.5220],
        [195.6502, 194.9737, 195.8100, 195.6444, 195.6702, 195.6025, 194.7997,
         195.5582, 196.4038, 195.4216],
        [195.5537, 194.8802, 195.7126, 195.5497, 195.5752, 195.5063, 194.7055,
         195.4631, 196.3097, 195.3271],
        [195.5568, 194

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.5336, 195.6622, 195.5852, 195.3599, 194.8871, 195.4326, 195.6051,
         195.4221, 195.8601, 196.5159],
        [195.3020, 195.4300, 195.3516, 195.1263, 194.6551, 195.1989, 195.3710,
         195.1900, 195.6270, 196.2828],
        [195.2439, 195.3729, 195.2935, 195.0679, 194.5971, 195.1433, 195.3139,
         195.1335, 195.5710, 196.2258],
        [195.3091, 195.4363, 195.3581, 195.1335, 194.6612, 195.2068, 195.3794,
         195.1976, 195.6354, 196.2915],
        [195.3232, 195.4516, 195.3744, 195.1493, 194.6756, 195.2194, 195.3943,
         195.2086, 195.6471, 196.3034],
        [195.4202, 195.5475, 195.4692, 195.2445, 194.7732, 195.3184, 195.4897,
         195.3098, 195.7467, 196.4027],
        [195.5613, 195.6900, 195.6126, 195.3873, 194.9153, 195.4608, 195.6322,
         195.4508, 195.8882, 196.5437],
        [195.5683, 195.6965, 195.6191, 195.3939, 194.92

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.2248, 195.5482, 194.7921, 194.9715, 195.7070, 194.7449, 195.4442,
         195.5593, 194.6296, 195.1847],
        [195.2704, 195.5927, 194.8363, 195.0154, 195.7526, 194.7884, 195.4891,
         195.6039, 194.6745, 195.2292],
        [195.1296, 195.4514, 194.6934, 194.8728, 195.6097, 194.6464, 195.3478,
         195.4620, 194.5341, 195.0876],
        [195.2837, 195.6080, 194.8500, 195.0309, 195.7673, 194.8022, 195.5044,
         195.6181, 194.6879, 195.2430],
        [194.9953, 195.3180, 194.5617, 194.7395, 195.4763, 194.5139, 195.2139,
         195.3291, 194.4002, 194.9538],
        [195.2849, 195.6083, 194.8518, 195.0316, 195.7676, 194.8044, 195.5045,
         195.6193, 194.6894, 195.2445],
        [195.3194, 195.6399, 194.8861, 195.0637, 195.7986, 194.8403, 195.5356,
         195.6525, 194.7246, 195.2791],
        [194.9386, 195.2616, 194.5031, 194.6820, 195.41

tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.4207, 194.7502, 195.4052, 196.0998, 195.5184, 195.4232, 195.5830,
         195.4210, 195.4176, 195.4457],
        [195.3297, 194.6586, 195.3126, 196.0100, 195.4266, 195.3326, 195.4936,
         195.3300, 195.3274, 195.3542],
        [195.5206, 194.8512, 195.5004, 196.2004, 195.6171, 195.5238, 195.6839,
         195.5202, 195.5173, 195.5452],
        [195.6373, 194.9680, 195.6160, 196.3163, 195.7331, 195.6393, 195.7991,
         195.6354, 195.6322, 195.6608],
        [195.1396, 194.4683, 195.1228, 195.8195, 195.2364, 195.1421, 195.3019,
         195.1395, 195.1368, 195.1635],
        [195.4149, 194.7446, 195.3957, 196.0954, 195.5114, 195.4182, 195.5792,
         195.4149, 195.4125, 195.4394],
        [195.5013, 194.8305, 195.4832, 196.1815, 195.5981, 195.5043, 195.6658,
         195.5015, 195.4987, 195.5259],
        [195.4185, 194.7483, 195.3999, 196.0989, 195.51

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.1133, 196.1629, 195.2755, 196.0050, 195.4008, 194.6866, 195.0791,
         195.4447, 194.8936, 194.9507],
        [195.1560, 196.2042, 195.3165, 196.0476, 195.4437, 194.7291, 195.1206,
         195.4887, 194.9361, 194.9929],
        [195.2830, 196.3313, 195.4441, 196.1742, 195.5700, 194.8574, 195.2484,
         195.6138, 195.0627, 195.1208],
        [195.2430, 196.2925, 195.4050, 196.1347, 195.5300, 194.8165, 195.2082,
         195.5739, 195.0234, 195.0807],
        [195.3597, 196.4085, 195.5206, 196.2512, 195.6463, 194.9357, 195.3242,
         195.6906, 195.1403, 195.1985],
        [195.1711, 196.2212, 195.3333, 196.0630, 195.4582, 194.7452, 195.1364,
         195.5022, 194.9518, 195.0090],
        [195.0766, 196.1250, 195.2368, 195.9682, 195.3644, 194.6511, 195.0415,
         195.4098, 194.8569, 194.9138],
        [195.2849, 196

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.1994, 194.4767, 194.8096, 195.9531, 194.2529, 194.8885, 194.9043,
         194.9730, 195.2176, 194.3252],
        [195.4407, 194.7201, 195.0522, 196.1958, 194.4949, 195.1319, 195.1459,
         195.2147, 195.4582, 194.5674],
        [195.1301, 194.4083, 194.7396, 195.8838, 194.1828, 194.8212, 194.8344,
         194.9045, 195.1475, 194.2558],
        [195.4270, 194.7041, 195.0356, 196.1788, 194.4786, 195.1150, 195.1308,
         195.1994, 195.4437, 194.5529],
        [195.4621, 194.7423, 195.0713, 196.2148, 194.5145, 195.1518, 195.1679,
         195.2346, 195.4812, 194.5883],
        [195.4844, 194.7630, 195.0939, 196.2369, 194.5372, 195.1721, 195.1899,
         195.2562, 195.5035, 194.6105],
        [195.3828, 194.6616, 194.9941, 196.1375, 194.4370, 195.0730, 195.0882,
         195.1564, 195.4010, 194.5093],
        [195.6485, 194

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[194.7038, 194.2202, 194.4499, 194.7540, 194.2427, 195.6899, 194.4206,
         194.5977, 195.0896, 195.1324],
        [195.0104, 194.5284, 194.7565, 195.0602, 194.5505, 195.9974, 194.7287,
         194.9046, 195.3966, 195.4393],
        [195.0941, 194.6108, 194.8387, 195.1455, 194.6352, 196.0801, 194.8127,
         194.9890, 195.4800, 195.5207],
        [195.2586, 194.7749, 195.0021, 195.3096, 194.8003, 196.2439, 194.9771,
         195.1537, 195.6431, 195.6832],
        [194.9792, 194.4951, 194.7245, 195.0295, 194.5208, 195.9650, 194.6970,
         194.8744, 195.3647, 195.4054],
        [195.0881, 194.6039, 194.8317, 195.1391, 194.6291, 196.0731, 194.8059,
         194.9830, 195.4724, 195.5128],
        [194.9234, 194.4402, 194.6692, 194.9733, 194.4637, 195.9097, 194.6411,
         194.8179, 195.3090, 195.3510],
        [195.0157, 194

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.2547, 195.0353, 195.2120, 195.5058, 195.1194, 195.4442, 195.3447,
         195.1191, 195.0329, 195.3414],
        [195.3319, 195.1116, 195.2900, 195.5833, 195.1958, 195.5223, 195.4220,
         195.1966, 195.1101, 195.4181],
        [195.4004, 195.1816, 195.3584, 195.6504, 195.2645, 195.5905, 195.4903,
         195.2652, 195.1784, 195.4869],
        [195.3042, 195.0847, 195.2617, 195.5554, 195.1691, 195.4939, 195.3941,
         195.1691, 195.0826, 195.3913],
        [195.4971, 195.2790, 195.4549, 195.7469, 195.3612, 195.6868, 195.5867,
         195.3623, 195.2751, 195.5838],
        [195.3221, 195.1035, 195.2813, 195.5715, 195.1874, 195.5131, 195.4117,
         195.1884, 195.1005, 195.4099],
        [195.2179, 194.9993, 195.1746, 195.4688, 195.0830, 195.4070, 195.3081,
         195.0820, 194.9961, 195.3048],
        [195.2009, 194.9849, 195.1581, 195.4491, 195.06

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.4912, 195.1241, 195.4717, 194.9976, 195.2136, 195.2282, 195.1124,
         194.8098, 195.3012, 195.2538],
        [195.2848, 194.9162, 195.2643, 194.7884, 195.0059, 195.0206, 194.9040,
         194.6020, 195.0931, 195.0458],
        [195.6278, 195.2620, 195.6093, 195.1355, 195.3514, 195.3650, 195.2505,
         194.9468, 195.4393, 195.3904],
        [195.3762, 195.0106, 195.3578, 194.8809, 195.0988, 195.1107, 194.9981,
         194.6938, 195.1877, 195.1366],
        [195.3176, 194.9519, 195.2986, 194.8237, 195.0398, 195.0548, 194.9410,
         194.6372, 195.1287, 195.0795],
        [195.5846, 195.2193, 195.5662, 195.0924, 195.3078, 195.3216, 195.2074,
         194.9039, 195.3962, 195.3476],
        [195.4554, 195.0900, 195.4368, 194.9622, 195.1788, 195.1929, 195.0799,
         194.7748, 195.2673, 195.2168],
        [195.4568, 195.0914, 195.4386, 194.9626, 195.1799, 195.1919, 195.0792,
         194.7749, 195.2686, 195

tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.1737, 194.4218, 196.1032, 194.9055, 195.0951, 195.1873, 195.1982,
         194.4000, 194.7763, 194.3263],
        [195.1619, 194.4128, 196.0903, 194.8942, 195.0836, 195.1772, 195.1842,
         194.3871, 194.7661, 194.3146],
        [195.1087, 194.3603, 196.0374, 194.8413, 195.0302, 195.1245, 195.1311,
         194.3344, 194.7139, 194.2624],
        [195.1822, 194.4302, 196.1107, 194.9137, 195.1041, 195.1962, 195.2055,
         194.4074, 194.7841, 194.3336],
        [195.1165, 194.3660, 196.0459, 194.8481, 195.0369, 195.1310, 195.1402,
         194.3419, 194.7201, 194.2695],
        [195.2307, 194.4826, 196.1591, 194.9637, 195.1532, 195.2464, 195.2527,
         194.4566, 194.8358, 194.3839],
        [195.1946, 194.4469, 196.1230, 194.9276, 195.1167, 195.2108, 195.2165,
         194.4204, 194.8001, 194.3482],
        [195.1952, 194.4426, 196.1229, 194.9268, 195.11

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.2793, 195.2534, 195.4083, 195.1763, 195.2362, 194.4556, 194.8332,
         195.4190, 195.8127, 195.3228],
        [195.2156, 195.1886, 195.3469, 195.1124, 195.1744, 194.3928, 194.7726,
         195.3578, 195.7493, 195.2619],
        [195.2040, 195.1785, 195.3351, 195.0998, 195.1624, 194.3820, 194.7597,
         195.3458, 195.7370, 195.2509],
        [195.1134, 195.0871, 195.2444, 195.0095, 195.0717, 194.2904, 194.6691,
         195.2551, 195.6462, 195.1596],
        [195.1860, 195.1602, 195.3143, 195.0820, 195.1429, 194.3621, 194.7389,
         195.3254, 195.7180, 195.2297],
        [195.2162, 195.1904, 195.3435, 195.1113, 195.1740, 194.3928, 194.7689,
         195.3558, 195.7463, 195.2609],
        [195.2174, 195.1911, 195.3502, 195.1149, 195.1757, 194.3949, 194.7750,
         195.3599, 195.7529, 195.2637],
        [194.8732, 194

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[194.8772, 195.0010, 194.4500, 195.1968, 195.7806, 194.3516, 194.6516,
         195.0040, 195.0545, 195.1237],
        [194.7650, 194.8890, 194.3368, 195.0850, 195.6685, 194.2395, 194.5397,
         194.8920, 194.9426, 195.0112],
        [194.7865, 194.9107, 194.3584, 195.1071, 195.6891, 194.2602, 194.5616,
         194.9135, 194.9638, 195.0315],
        [194.7992, 194.9226, 194.3674, 195.1190, 195.7004, 194.2719, 194.5724,
         194.9258, 194.9772, 195.0420],
        [194.7244, 194.8481, 194.2958, 195.0440, 195.6280, 194.1991, 194.4987,
         194.8512, 194.9021, 194.9710],
        [194.8437, 194.9686, 194.4144, 195.1650, 195.7463, 194.3167, 194.6186,
         194.9711, 195.0218, 195.0870],
        [194.9014, 195.0263, 194.4728, 195.2222, 195.8054, 194.3752, 194.6758,
         195.0289, 195.0802, 195.1460],
        [194.8569, 194

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[194.9264, 194.6830, 195.4149, 195.0214, 195.1629, 194.7417, 195.0674,
         194.7599, 194.6123, 194.8300],
        [194.7585, 194.5110, 195.2446, 194.8512, 194.9917, 194.5697, 194.8963,
         194.5920, 194.4410, 194.6591],
        [195.0735, 194.8294, 195.5606, 195.1673, 195.3091, 194.8855, 195.2128,
         194.9084, 194.7590, 194.9767],
        [195.1260, 194.8813, 195.6157, 195.2216, 195.3614, 194.9413, 195.2667,
         194.9590, 194.8129, 195.0304],
        [194.7216, 194.4757, 195.2090, 194.8156, 194.9561, 194.5342, 194.8609,
         194.5543, 194.4046, 194.6241],
        [194.9745, 194.7269, 195.4617, 195.0678, 195.2076, 194.7862, 195.1126,
         194.8081, 194.6586, 194.8760],
        [195.0007, 194.7542, 195.4890, 195.0951, 195.2348, 194.8140, 195.1399,
         194.8338, 194.6858, 194.9034],
        [195.1099, 194

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.1274, 195.2095, 195.3016, 194.7529, 194.9152, 194.6967, 195.0087,
         195.6694, 195.1116, 194.3642],
        [194.8521, 194.9359, 195.0270, 194.4764, 194.6410, 194.4216, 194.7328,
         195.3960, 194.8375, 194.0869],
        [194.7467, 194.8312, 194.9218, 194.3705, 194.5355, 194.3165, 194.6275,
         195.2914, 194.7325, 193.9804],
        [195.0910, 195.1750, 195.2654, 194.7169, 194.8809, 194.6604, 194.9725,
         195.6357, 195.0770, 194.3268],
        [195.0362, 195.1191, 195.2109, 194.6613, 194.8247, 194.6055, 194.9166,
         195.5788, 195.0207, 194.2729],
        [195.0933, 195.1792, 195.2686, 194.7194, 194.8833, 194.6639, 194.9754,
         195.6391, 195.0797, 194.3286],
        [194.9796, 195.0639, 195.1551, 194.6046, 194.7679, 194.5499, 194.8605,
         195.5230, 194.9644, 194.2154],
        [195.0160, 195

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.8145, 195.0629, 195.2120, 195.2162, 195.1214, 194.5000, 195.2562,
         194.8020, 195.4595, 195.0759],
        [195.9784, 195.2271, 195.3761, 195.3800, 195.2856, 194.6653, 195.4199,
         194.9657, 195.6219, 195.2396],
        [195.9308, 195.1814, 195.3281, 195.3321, 195.2385, 194.6190, 195.3728,
         194.9188, 195.5735, 195.1937],
        [195.6752, 194.9244, 195.0737, 195.0771, 194.9834, 194.3612, 195.1190,
         194.6629, 195.3176, 194.9380],
        [195.7183, 194.9664, 195.1167, 195.1201, 195.0260, 194.4034, 195.1620,
         194.7063, 195.3617, 194.9802],
        [195.7428, 194.9920, 195.1402, 195.1446, 195.0497, 194.4286, 195.1845,
         194.7303, 195.3877, 195.0049],
        [195.8561, 195.1064, 195.2536, 195.2575, 195.1638, 194.5436, 195.2984,
         194.8441, 195.4988, 195.1190],
        [195.7970, 195

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.2663, 194.6899, 194.7798, 194.6133, 194.7690, 195.0586, 194.7961,
         195.9368, 195.8572, 195.1056],
        [195.3031, 194.7285, 194.8194, 194.6547, 194.8091, 195.0962, 194.8343,
         195.9749, 195.8979, 195.1443],
        [195.3432, 194.7683, 194.8589, 194.6942, 194.8491, 195.1359, 194.8744,
         196.0149, 195.9375, 195.1837],
        [195.2474, 194.6731, 194.7638, 194.5995, 194.7543, 195.0406, 194.7787,
         195.9195, 195.8428, 195.0887],
        [195.4927, 194.9168, 195.0076, 194.8423, 194.9969, 195.2847, 195.0237,
         196.1636, 196.0851, 195.3323],
        [195.3102, 194.7333, 194.8245, 194.6577, 194.8132, 195.1024, 194.8398,
         195.9808, 195.9009, 195.1494],
        [195.3602, 194.7860, 194.8739, 194.7091, 194.8632, 195.1527, 194.8919,
         196.0309, 195.9536, 195.2009],
        [195.1593, 194

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[194.9765, 195.5119, 194.9552, 195.0499, 195.0306, 194.8281, 194.9710,
         194.6533, 195.0028, 194.2189],
        [194.9021, 195.4372, 194.8801, 194.9748, 194.9560, 194.7536, 194.8953,
         194.5780, 194.9279, 194.1440],
        [194.9583, 195.4929, 194.9345, 195.0288, 195.0091, 194.8066, 194.9477,
         194.6309, 194.9831, 194.1992],
        [194.9147, 195.4498, 194.8918, 194.9858, 194.9651, 194.7632, 194.9053,
         194.5883, 194.9401, 194.1561],
        [195.0154, 195.5510, 194.9934, 195.0871, 195.0654, 194.8650, 195.0074,
         194.6913, 195.0412, 194.2574],
        [195.1880, 195.7232, 195.1660, 195.2604, 195.2400, 195.0376, 195.1808,
         194.8640, 195.2140, 194.4303],
        [195.0127, 195.5481, 194.9913, 195.0858, 195.0660, 194.8641, 195.0067,
         194.6896, 195.0389, 194.2551],
        [194.9624, 195

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.1862, 195.1902, 195.9434, 194.5229, 195.1687, 195.5366, 195.0400,
         194.9931, 194.9427, 194.8662],
        [194.9319, 194.9353, 195.6871, 194.2655, 194.9142, 195.2822, 194.7854,
         194.7368, 194.6872, 194.6106],
        [194.9841, 194.9874, 195.7396, 194.3175, 194.9672, 195.3341, 194.8373,
         194.7883, 194.7397, 194.6623],
        [194.8434, 194.8468, 195.5986, 194.1760, 194.8272, 195.1933, 194.6969,
         194.6468, 194.5986, 194.5217],
        [195.0329, 195.0366, 195.7890, 194.3682, 195.0145, 195.3833, 194.8863,
         194.8393, 194.7887, 194.7121],
        [195.0044, 195.0084, 195.7622, 194.3400, 194.9894, 195.3539, 194.8575,
         194.8089, 194.7612, 194.6831],
        [194.8929, 194.8970, 195.6506, 194.2288, 194.8773, 195.2424, 194.7463,
         194.6981, 194.6492, 194.5721],
        [194.8844, 194

tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[194.5637, 194.5540, 195.2372, 195.1292, 194.9144, 195.3539, 195.4864,
         195.7210, 194.5992, 195.5638],
        [194.9011, 194.8932, 195.5732, 195.4683, 195.2517, 195.6891, 195.8234,
         196.0584, 194.9364, 195.8994],
        [194.6631, 194.6570, 195.3362, 195.2332, 195.0149, 195.4532, 195.5872,
         195.8218, 194.6994, 195.6643],
        [194.3687, 194.3601, 195.0417, 194.9360, 194.7197, 195.1585, 195.2915,
         195.5259, 194.4040, 195.3689],
        [194.8177, 194.8073, 195.4903, 195.3824, 195.1684, 195.6060, 195.7395,
         195.9741, 194.8528, 195.8159],
        [194.5796, 194.5721, 195.2519, 195.1484, 194.9312, 195.3683, 195.5028,
         195.7369, 194.6151, 195.5795],
        [194.7302, 194.7230, 195.4023, 195.2986, 195.0811, 195.5182, 195.6528,
         195.8875, 194.7655, 195.7290],
        [194.5697, 194.5614, 195.2414, 195.1373, 194.92

       grad_fn=<CdistBackward0>)
tensor([4, 4, 0, 4, 4, 0, 4, 0, 0, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[194.8928, 194.2216, 194.5551, 194.5049, 194.9153, 194.8913, 195.7219,
         194.1754, 194.6140, 194.6541],
        [194.8075, 194.1359, 194.4717, 194.4208, 194.8303, 194.8066, 195.6372,
         194.0886, 194.5285, 194.5694],
        [195.2054, 194.5354, 194.8706, 194.8201, 195.2291, 195.2045, 196.0353,
         194.4892, 194.9282, 194.9672],
        [194.6533, 193.9830, 194.3157, 194.2660, 194.6757, 194.6518, 195.4814,
         193.9349, 194.3733, 194.4136],
        [195.1227, 194.4529, 194.7847, 194.7353, 195.1458, 195.1209, 195.9510,
         194.4073, 194.8444, 194.8828],
        [194.8965, 194.2253, 194.5593, 194.5089, 194.9191, 194.8951, 195.7255,
         194.1789, 194.6177, 194.6577],
        [194.8747, 194.2052, 194.5388, 194.4888, 194.8980, 194.8736, 195.7031,
         194.1574, 194.5958, 194.6350],
        [194.9635, 194

tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[194.4734, 195.5996, 195.1673, 194.8565, 194.6799, 194.5918, 194.7869,
         194.7734, 194.6185, 194.4377],
        [194.6267, 195.7527, 195.3194, 195.0073, 194.8328, 194.7418, 194.9413,
         194.9285, 194.7717, 194.5903],
        [194.1854, 195.3138, 194.8812, 194.5692, 194.3938, 194.3043, 194.5025,
         194.4884, 194.3320, 194.1505],
        [194.4733, 195.5988, 195.1681, 194.8533, 194.6795, 194.5912, 194.7880,
         194.7752, 194.6181, 194.4388],
        [194.3048, 195.4320, 194.9996, 194.6879, 194.5121, 194.4235, 194.6203,
         194.6065, 194.4505, 194.2696],
        [194.5548, 195.6803, 195.2509, 194.9357, 194.7612, 194.6736, 194.8678,
         194.8548, 194.6999, 194.5209],
        [194.3348, 195.4615, 195.0321, 194.7175, 194.5423, 194.4553, 194.6490,
         194.6353, 194.4807, 194.3013],
        [194.7126, 195.8374, 195.4059, 195.0932, 194.91

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[194.0285, 194.9189, 195.1325, 194.7561, 194.4815, 194.2285, 194.1803,
         194.5086, 194.7234, 194.6473],
        [194.1903, 195.0822, 195.2955, 194.9177, 194.6431, 194.3918, 194.3436,
         194.6716, 194.8855, 194.8093],
        [193.7424, 194.6336, 194.8492, 194.4708, 194.1960, 193.9427, 193.8927,
         194.2242, 194.4357, 194.3621],
        [193.8912, 194.7829, 195.0001, 194.6172, 194.3441, 194.0931, 194.0405,
         194.3743, 194.5828, 194.5106],
        [194.0860, 194.9776, 195.1943, 194.8120, 194.5388, 194.2881, 194.2368,
         194.5690, 194.7786, 194.7051],
        [193.9404, 194.8316, 195.0460, 194.6685, 194.3937, 194.1409, 194.0920,
         194.4217, 194.6348, 194.5596],
        [194.0319, 194.9234, 195.1399, 194.7584, 194.4849, 194.2336, 194.1827,
         194.5146, 194.7247, 194.6510],
        [193.7207, 194.6122, 194.8273, 194.4495, 194.17

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[194.4488, 194.0274, 195.4476, 194.5570, 193.8671, 193.8295, 194.5226,
         194.0577, 194.7894, 194.5432],
        [194.4635, 194.0410, 195.4618, 194.5714, 193.8820, 193.8444, 194.5366,
         194.0716, 194.8037, 194.5564],
        [194.6046, 194.1825, 195.6027, 194.7118, 194.0230, 193.9861, 194.6786,
         194.2121, 194.9446, 194.6977],
        [194.5065, 194.0847, 195.5060, 194.6163, 193.9243, 193.8873, 194.5789,
         194.1149, 194.8475, 194.5997],
        [194.3629, 193.9404, 195.3616, 194.4714, 193.7811, 193.7435, 194.4352,
         193.9711, 194.7034, 194.4557],
        [194.6157, 194.1951, 195.6137, 194.7217, 194.0342, 193.9970, 194.6917,
         194.2240, 194.9554, 194.7112],
        [194.5375, 194.1153, 195.5362, 194.6461, 193.9562, 193.9185, 194.6108,
         194.1460, 194.8782, 194.6309],
        [194.6976, 194

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[194.8104, 194.6314, 194.9442, 195.0411, 194.5263, 194.4872, 194.6177,
         195.5355, 194.6929, 194.8000],
        [194.5631, 194.3840, 194.6974, 194.7939, 194.2794, 194.2380, 194.3690,
         195.2880, 194.4457, 194.5524],
        [194.4581, 194.2790, 194.5924, 194.6891, 194.1744, 194.1328, 194.2644,
         195.1831, 194.3410, 194.4482],
        [194.7337, 194.5540, 194.8667, 194.9642, 194.4491, 194.4089, 194.5396,
         195.4580, 194.6156, 194.7218],
        [194.7545, 194.5748, 194.8873, 194.9851, 194.4699, 194.4301, 194.5607,
         195.4788, 194.6363, 194.7431],
        [194.6968, 194.5178, 194.8313, 194.9278, 194.4128, 194.3730, 194.5045,
         195.4220, 194.5794, 194.6861],
        [194.7941, 194.6146, 194.9273, 195.0245, 194.5097, 194.4697, 194.6002,
         195.5186, 194.6761, 194.7823],
        [194.7717, 194

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[194.0042, 195.1522, 194.5177, 194.3607, 193.9550, 194.4226, 194.6940,
         194.5811, 194.6712, 194.5760],
        [193.9983, 195.1484, 194.5082, 194.3536, 193.9456, 194.4155, 194.6872,
         194.5734, 194.6634, 194.5679],
        [194.3327, 195.4816, 194.8427, 194.6887, 194.2820, 194.7491, 195.0211,
         194.9078, 194.9973, 194.9020],
        [194.2071, 195.3537, 194.7195, 194.5641, 194.1576, 194.6248, 194.8958,
         194.7835, 194.8728, 194.7775],
        [194.0968, 195.2460, 194.6079, 194.4523, 194.0457, 194.5140, 194.7858,
         194.6723, 194.7621, 194.6664],
        [193.9253, 195.0750, 194.4373, 194.2814, 193.8746, 194.3434, 194.6151,
         194.5016, 194.5921, 194.4969],
        [194.2549, 195.4033, 194.7677, 194.6109, 194.2064, 194.6721, 194.9445,
         194.8312, 194.9211, 194.8258],
        [194.0616, 195

tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[194.4843, 194.6393, 194.2740, 194.7412, 194.8183, 193.9940, 194.6594,
         194.3228, 194.5194, 195.2528],
        [194.6027, 194.7518, 194.3889, 194.8544, 194.9336, 194.1092, 194.7744,
         194.4362, 194.6335, 195.3686],
        [194.2563, 194.4093, 194.0448, 194.5119, 194.5879, 193.7635, 194.4301,
         194.0927, 194.2901, 195.0238],
        [194.5520, 194.7062, 194.3421, 194.8081, 194.8846, 194.0614, 194.7257,
         194.3887, 194.5876, 195.3202],
        [194.5825, 194.7350, 194.3706, 194.8368, 194.9159, 194.0909, 194.7560,
         194.4184, 194.6154, 195.3499],
        [194.4766, 194.6273, 194.2638, 194.7299, 194.8075, 193.9835, 194.6489,
         194.3112, 194.5088, 195.2431],
        [194.5001, 194.6517, 194.2875, 194.7538, 194.8328, 194.0077, 194.6733,
         194.3354, 194.5322, 195.2671],
        [194.6446, 194.8001, 194.4356, 194.9020, 194.97

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.8442, 194.3849, 194.6760, 194.4010, 194.3137, 194.5773, 194.5378,
         194.7109, 193.7710, 194.6209],
        [193.9418, 194.4839, 194.7754, 194.4985, 194.4090, 194.6775, 194.6326,
         194.8076, 193.8709, 194.7180],
        [193.7584, 194.3018, 194.5936, 194.3153, 194.2258, 194.4942, 194.4502,
         194.6258, 193.6877, 194.5352],
        [193.9592, 194.4990, 194.7899, 194.5161, 194.4285, 194.6925, 194.6518,
         194.8249, 193.8862, 194.7355],
        [193.8298, 194.3724, 194.6638, 194.3863, 194.2980, 194.5659, 194.5215,
         194.6965, 193.7592, 194.6066],
        [193.9584, 194.5014, 194.7927, 194.5153, 194.4266, 194.6937, 194.6509,
         194.8252, 193.8868, 194.7351],
        [193.7657, 194.3089, 194.6004, 194.3224, 194.2325, 194.5013, 194.4570,
         194.6326, 193.6948, 194.5421],
        [193.7767, 194

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[194.4089, 193.7103, 195.3774, 195.3000, 194.5742, 194.3282, 194.0404,
         194.2166, 194.5021, 194.5505],
        [194.5141, 193.8144, 195.4806, 195.4034, 194.6781, 194.4311, 194.1444,
         194.3180, 194.6057, 194.6547],
        [194.4615, 193.7621, 195.4279, 195.3516, 194.6255, 194.3791, 194.0920,
         194.2662, 194.5528, 194.6020],
        [194.5348, 193.8332, 195.5032, 195.4241, 194.6987, 194.4527, 194.1657,
         194.3400, 194.6280, 194.6759],
        [194.7090, 194.0095, 195.6756, 195.5977, 194.8730, 194.6256, 194.3393,
         194.5118, 194.8006, 194.8496],
        [194.4633, 193.7616, 195.4284, 195.3525, 194.6258, 194.3801, 194.0933,
         194.2661, 194.5535, 194.6029],
        [194.5504, 193.8491, 195.5190, 195.4399, 194.7145, 194.4686, 194.1815,
         194.3560, 194.6437, 194.6916],
        [194.7908, 194

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.9419, 194.3508, 193.9138, 193.4848, 194.1870, 193.8713, 194.9202,
         194.1744, 194.3235, 194.1059],
        [194.4106, 194.8203, 194.3819, 193.9570, 194.6552, 194.3429, 195.3891,
         194.6438, 194.7916, 194.5726],
        [194.3082, 194.7159, 194.2779, 193.8536, 194.5538, 194.2378, 195.2850,
         194.5379, 194.6870, 194.4691],
        [194.2163, 194.6254, 194.1871, 193.7619, 194.4616, 194.1471, 195.1943,
         194.4484, 194.5969, 194.3783],
        [194.1409, 194.5488, 194.1109, 193.6849, 194.3865, 194.0697, 195.1181,
         194.3701, 194.5209, 194.3021],
        [194.2498, 194.6591, 194.2209, 193.7944, 194.4945, 194.1807, 195.2282,
         194.4804, 194.6319, 194.4111],
        [194.3255, 194.7333, 194.2957, 193.8705, 194.5710, 194.2552, 195.3026,
         194.5550, 194.7047, 194.4867],
        [194.4883, 194

tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[194.7611, 194.4839, 194.7209, 194.3089, 193.7852, 194.4336, 194.6422,
         194.5710, 194.0579, 194.0839],
        [194.5877, 194.3103, 194.5475, 194.1351, 193.6131, 194.2602, 194.4698,
         194.3987, 193.8848, 193.9115],
        [194.5824, 194.3057, 194.5421, 194.1296, 193.6086, 194.2550, 194.4637,
         194.3940, 193.8797, 193.9054],
        [194.7011, 194.4247, 194.6618, 194.2492, 193.7276, 194.3744, 194.5833,
         194.5128, 193.9995, 194.0249],
        [194.6693, 194.3917, 194.6269, 194.2154, 193.6924, 194.3411, 194.5498,
         194.4800, 193.9649, 193.9916],
        [194.6814, 194.4042, 194.6414, 194.2290, 193.7057, 194.3536, 194.5620,
         194.4908, 193.9779, 194.0039],
        [194.6800, 194.4028, 194.6407, 194.2280, 193.7054, 194.3528, 194.5623,
         194.4905, 193.9775, 194.0039],
        [194.5427, 194.2650, 194.5027, 194.0900, 193.56

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.5019, 193.9940, 194.3710, 193.9312, 194.8137, 193.4887, 193.6924,
         194.3308, 194.2973, 194.9229],
        [193.8359, 194.3294, 194.7033, 194.2671, 195.1496, 193.8275, 194.0271,
         194.6653, 194.6331, 195.2580],
        [193.8710, 194.3617, 194.7388, 194.3040, 195.1816, 193.8606, 194.0618,
         194.7006, 194.6665, 195.2925],
        [193.7394, 194.2329, 194.6073, 194.1716, 195.0525, 193.7311, 193.9314,
         194.5694, 194.5366, 195.1616],
        [193.6591, 194.1528, 194.5273, 194.0895, 194.9727, 193.6494, 193.8504,
         194.4885, 194.4560, 195.0811],
        [193.7212, 194.2135, 194.5892, 194.1530, 195.0331, 193.7114, 193.9127,
         194.5510, 194.5180, 195.1428],
        [193.8371, 194.3284, 194.7044, 194.2668, 195.1494, 193.8257, 194.0268,
         194.6656, 194.6335, 195.2577],
        [193.6876, 194.1773, 194.5564, 194.1187, 194.99

       grad_fn=<CdistBackward0>)
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[194.4737, 194.5495, 194.6302, 194.3980, 194.4699, 194.2572, 194.5436,
         194.2258, 194.4997, 194.5163],
        [194.4088, 194.4818, 194.5606, 194.3305, 194.4005, 194.1900, 194.4772,
         194.1572, 194.4336, 194.4475],
        [194.4643, 194.5376, 194.6191, 194.3861, 194.4570, 194.2458, 194.5327,
         194.2132, 194.4889, 194.5035],
        [194.3725, 194.4471, 194.5281, 194.2954, 194.3670, 194.1549, 194.4418,
         194.1228, 194.3979, 194.4133],
        [194.5099, 194.5853, 194.6656, 194.4337, 194.5055, 194.2932, 194.5795,
         194.2615, 194.5357, 194.5519],
        [194.4203, 194.4946, 194.5765, 194.3444, 194.4145, 194.2030, 194.4898,
         194.1716, 194.4458, 194.4622],
        [194.6656, 194.7404, 194.8190, 194.5887, 194.6600, 194.4486, 194.7346,
         194.4165, 194.6912, 194.7066],
        [194.5718, 194

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[194.7859, 194.5455, 194.3542, 194.7187, 195.3599, 194.7348, 194.8663,
         194.4791, 194.3708, 195.5930],
        [194.4462, 194.2070, 194.0165, 194.3799, 195.0213, 194.3944, 194.5276,
         194.1402, 194.0332, 195.2534],
        [194.5894, 194.3496, 194.1577, 194.5221, 195.1636, 194.5390, 194.6711,
         194.2825, 194.1734, 195.3963],
        [194.5958, 194.3533, 194.1625, 194.5283, 195.1684, 194.5439, 194.6751,
         194.2878, 194.1779, 195.4021],
        [194.6071, 194.3672, 194.1755, 194.5399, 195.1812, 194.5562, 194.6884,
         194.3003, 194.1909, 195.4142],
        [194.4484, 194.2054, 194.0146, 194.3807, 195.0207, 194.3967, 194.5280,
         194.1400, 194.0298, 195.2544],
        [194.5686, 194.3290, 194.1371, 194.5014, 195.1429, 194.5180, 194.6503,
         194.2618, 194.1522, 195.3758],
        [194.5683, 194

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[194.3917, 194.0142, 194.7050, 194.2934, 194.3726, 193.6023, 194.3772,
         194.4845, 195.2910, 194.2453],
        [194.3949, 194.0143, 194.7043, 194.2934, 194.3740, 193.6030, 194.3773,
         194.4876, 195.2935, 194.2440],
        [194.1081, 193.7268, 194.4186, 194.0074, 194.0878, 193.3168, 194.0914,
         194.1995, 195.0058, 193.9606],
        [194.3156, 193.9348, 194.6261, 194.2139, 194.2947, 193.5240, 194.2982,
         194.4081, 195.2142, 194.1656],
        [194.3740, 193.9922, 194.6834, 194.2724, 194.3523, 193.5825, 194.3561,
         194.4649, 195.2719, 194.2242],
        [194.3935, 194.0132, 194.7015, 194.2946, 194.3732, 193.6024, 194.3769,
         194.4849, 195.2911, 194.2447],
        [194.3634, 193.9824, 194.6721, 194.2625, 194.3424, 193.5717, 194.3458,
         194.4552, 195.2614, 194.2134],
        [194.6465, 194

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.2296, 194.2244, 194.5627, 194.4741, 194.2549, 193.8977, 194.2747,
         194.4243, 194.4109, 193.9261],
        [195.1574, 194.1563, 194.4923, 194.4021, 194.1832, 193.8247, 194.2035,
         194.3510, 194.3419, 193.8548],
        [195.0126, 194.0084, 194.3465, 194.2563, 194.0373, 193.6783, 194.0578,
         194.2065, 194.1957, 193.7091],
        [195.0456, 194.0415, 194.3795, 194.2893, 194.0703, 193.7110, 194.0909,
         194.2391, 194.2289, 193.7419],
        [194.9433, 193.9410, 194.2777, 194.1871, 193.9684, 193.6093, 193.9887,
         194.1367, 194.1267, 193.6402],
        [194.9243, 193.9209, 194.2583, 194.1675, 193.9489, 193.5892, 193.9695,
         194.1175, 194.1076, 193.6207],
        [195.1131, 194.1091, 194.4467, 194.3573, 194.1383, 193.7802, 194.1583,
         194.3071, 194.2953, 193.8097],
        [194.9388, 193

tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
Epoch: 318 Training Loss: 14.011439323425293 Training Accuracy: 0.1
	Learning_rate: 1e-06
tensor([[194.7402, 193.8706, 194.0829, 195.1059, 194.0614, 193.8997, 194.2823,
         193.6205, 194.2036, 194.0366],
        [194.8719, 194.0021, 194.2144, 195.2384, 194.1937, 194.0320, 194.4135,
         193.7533, 194.3355, 194.1684],
        [194.7054, 193.8359, 194.0470, 195.0708, 194.0265, 193.8646, 194.2477,
         193.5852, 194.1677, 194.0017],
        [194.7291, 193.8590, 194.0719, 195.0949, 194.0500, 193.8884, 194.2705,
         193.6094, 194.1928, 194.0253],
        [194.7307, 193.8610, 194.0734, 195.0956, 194.0512, 193.8892, 194.2727,
         193.6101, 194.1933, 194.0268],
        [194.5093, 193.6382, 193.8510, 194.8737, 193.8277, 193.6660, 194.0494,
         193.3866, 193.9710, 193.8044],
        [194.7529, 193.8826, 194.0943, 195.1196, 194.0741, 193.9128, 194.2939,
     

tensor([[195.0620, 194.2398, 194.2637, 193.8424, 194.3203, 194.9039, 194.9862,
         193.9620, 194.0998, 194.3385],
        [195.0809, 194.2580, 194.2811, 193.8607, 194.3361, 194.9230, 195.0022,
         193.9803, 194.1179, 194.3549],
        [195.3093, 194.4870, 194.5109, 194.0908, 194.5649, 195.1509, 195.2320,
         194.2100, 194.3470, 194.5847],
        [195.1334, 194.3109, 194.3351, 193.9136, 194.3918, 194.9751, 195.0576,
         194.0337, 194.1706, 194.4100],
        [195.1373, 194.3140, 194.3377, 193.9174, 194.3948, 194.9795, 195.0608,
         194.0370, 194.1738, 194.4123],
        [195.3006, 194.4785, 194.5019, 194.0836, 194.5562, 195.1433, 195.2241,
         194.2015, 194.3391, 194.5755],
        [195.1374, 194.3138, 194.3370, 193.9165, 194.3931, 194.9793, 195.0588,
         194.0367, 194.1732, 194.4113],
        [195.0003, 194.1781, 194.2012, 193.7815, 194.2572, 194.8431, 194.9236,
         193.9000, 194.0388, 194.2754],
        [195.0117, 194.1884, 194.2125, 193.7892,

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[195.1144, 194.4487, 194.4145, 194.3555, 195.0400, 193.8128, 194.1608,
         193.7875, 194.0376, 194.3602],
        [194.8975, 194.2326, 194.1986, 194.1390, 194.8230, 193.5958, 193.9440,
         193.5698, 193.8199, 194.1425],
        [194.8756, 194.2118, 194.1788, 194.1174, 194.8017, 193.5747, 193.9227,
         193.5489, 193.7980, 194.1208],
        [194.9411, 194.2759, 194.2415, 194.1820, 194.8665, 193.6387, 193.9875,
         193.6136, 193.8637, 194.1857],
        [195.0358, 194.3716, 194.3383, 194.2769, 194.9622, 193.7347, 194.0830,
         193.7103, 193.9589, 194.2814],
        [195.0422, 194.3785, 194.3458, 194.2842, 194.9686, 193.7412, 194.0898,
         193.7165, 193.9649, 194.2879],
        [194.9979, 194.3328, 194.2987, 194.2389, 194.9237, 193.6964, 194.0445,
         193.6712, 193.9210, 194.2434],
        [194.8733, 194

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.3584, 194.1429, 194.0938, 193.4292, 193.8796, 194.3120, 194.1995,
         194.2601, 194.7510, 194.9568],
        [193.2337, 194.0155, 193.9670, 193.3036, 193.7547, 194.1843, 194.0729,
         194.1330, 194.6241, 194.8314],
        [193.2264, 194.0069, 193.9585, 193.2965, 193.7475, 194.1758, 194.0648,
         194.1242, 194.6163, 194.8241],
        [193.3336, 194.1169, 194.0701, 193.4056, 193.8563, 194.2867, 194.1755,
         194.2337, 194.7274, 194.9335],
        [193.3938, 194.1759, 194.1257, 193.4631, 193.9136, 194.3440, 194.2318,
         194.2937, 194.7828, 194.9901],
        [193.5890, 194.3719, 194.3228, 193.6600, 194.1101, 194.5404, 194.4284,
         194.4894, 194.9799, 195.1866],
        [193.3147, 194.0976, 194.0480, 193.3858, 193.8362, 194.2668, 194.1547,
         194.2142, 194.7070, 194.9134],
        [193.2504, 194

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.7323, 194.9055, 193.5255, 193.9344, 193.7877, 194.1472, 194.1510,
         194.9852, 194.0041, 194.0668],
        [193.8815, 195.0537, 193.6769, 194.0866, 193.9377, 194.2957, 194.3002,
         195.1345, 194.1511, 194.2176],
        [193.7315, 194.9033, 193.5252, 193.9350, 193.7874, 194.1454, 194.1498,
         194.9858, 194.0015, 194.0652],
        [193.7735, 194.9462, 193.5678, 193.9748, 193.8300, 194.1891, 194.1924,
         195.0265, 194.0456, 194.1081],
        [193.7280, 194.9007, 193.5221, 193.9293, 193.7845, 194.1437, 194.1468,
         194.9809, 194.0001, 194.0625],
        [193.7071, 194.8804, 193.5005, 193.9091, 193.7626, 194.1222, 194.1258,
         194.9601, 193.9790, 194.0417],
        [193.6966, 194.8697, 193.4908, 193.8989, 193.7526, 194.1121, 194.1154,
         194.9494, 193.9682, 194.0318],
        [193.7164, 194

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[194.4777, 194.1523, 194.4601, 194.2602, 194.4599, 193.7997, 193.9517,
         193.5977, 193.7467, 194.4976],
        [194.3371, 194.0126, 194.3208, 194.1212, 194.3199, 193.6592, 193.8125,
         193.4584, 193.6072, 194.3590],
        [194.1575, 193.8302, 194.1396, 193.9395, 194.1390, 193.4766, 193.6301,
         193.2760, 193.4245, 194.1790],
        [194.1132, 193.7881, 194.0969, 193.8959, 194.0966, 193.4353, 193.5887,
         193.2344, 193.3813, 194.1356],
        [194.1611, 193.8337, 194.1431, 193.9429, 194.1428, 193.4805, 193.6337,
         193.2796, 193.4279, 194.1823],
        [194.1361, 193.8092, 194.1192, 193.9193, 194.1184, 193.4563, 193.6101,
         193.2554, 193.4046, 194.1578],
        [194.2291, 193.9030, 194.2111, 194.0104, 194.2110, 193.5495, 193.7023,
         193.3488, 193.4956, 194.2506],
        [194.2240, 193

tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[194.1220, 193.9696, 193.6711, 193.8329, 194.7869, 193.4081, 194.1193,
         193.9539, 194.1418, 193.1926],
        [194.4216, 194.2705, 193.9729, 194.1335, 195.0869, 193.7106, 194.4191,
         194.2515, 194.4421, 193.4913],
        [194.1830, 194.0323, 193.7341, 193.8962, 194.8495, 193.4716, 194.1828,
         194.0135, 194.2059, 193.2537],
        [194.1293, 193.9778, 193.6788, 193.8403, 194.7943, 193.4163, 194.1266,
         193.9603, 194.1492, 193.1995],
        [194.3377, 194.1873, 193.8873, 194.0474, 195.0024, 193.6263, 194.3331,
         194.1646, 194.3553, 193.4052],
        [194.1876, 194.0361, 193.7370, 193.8979, 194.8521, 193.4747, 194.1839,
         194.0180, 194.2063, 193.2572],
        [194.3750, 194.2249, 193.9257, 194.0857, 195.0400, 193.6644, 194.3712,
         194.2029, 194.3939, 193.4434],
        [193.9851, 193.8338, 193.5328, 193.6944, 194.64

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.8840, 194.1925, 194.2111, 193.9684, 194.1958, 193.7789, 194.0504,
         194.0610, 193.6828, 194.9348],
        [193.7755, 194.0820, 194.1001, 193.8596, 194.0859, 193.6694, 193.9397,
         193.9501, 193.5729, 194.8249],
        [194.0937, 194.3990, 194.4167, 194.1767, 194.4030, 193.9870, 194.2587,
         194.2686, 193.8909, 195.1431],
        [194.0508, 194.3557, 194.3734, 194.1331, 194.3592, 193.9433, 194.2151,
         194.2252, 193.8475, 195.0997],
        [193.6017, 193.9088, 193.9264, 193.6865, 193.9121, 193.4957, 193.7682,
         193.7775, 193.3995, 194.6523],
        [193.8977, 194.2033, 194.2203, 193.9810, 194.2064, 193.7906, 194.0644,
         194.0734, 193.6949, 194.9481],
        [193.7297, 194.0370, 194.0544, 193.8143, 194.0402, 193.6239, 193.8973,
         193.9064, 193.5279, 194.7810],
        [193.9536, 194

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.8629, 193.7514, 193.4832, 193.6810, 193.2632, 193.1369, 193.9276,
         194.2294, 193.4487, 193.4804],
        [194.2413, 194.1266, 193.8624, 194.0609, 193.6426, 193.5158, 194.3072,
         194.6085, 193.8285, 193.8600],
        [194.2415, 194.1278, 193.8619, 194.0603, 193.6424, 193.5150, 194.3064,
         194.6073, 193.8277, 193.8592],
        [194.1732, 194.0600, 193.7945, 193.9924, 193.5742, 193.4472, 194.2401,
         194.5408, 193.7607, 193.7922],
        [194.2089, 194.0961, 193.8301, 194.0279, 193.6099, 193.4824, 194.2757,
         194.5763, 193.7955, 193.8274],
        [194.2858, 194.1734, 193.9071, 194.1048, 193.6866, 193.5597, 194.3526,
         194.6528, 193.8745, 193.9050],
        [194.2231, 194.1101, 193.8443, 194.0420, 193.6239, 193.4968, 194.2898,
         194.5903, 193.8104, 193.8419],
        [194.1309, 194

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[194.3414, 194.0870, 194.9859, 194.0484, 194.0360, 193.9630, 195.1449,
         194.1019, 194.6640, 194.0027],
        [194.0313, 193.7760, 194.6756, 193.7372, 193.7255, 193.6522, 194.8340,
         193.7923, 194.3536, 193.6927],
        [194.2399, 193.9857, 194.8860, 193.9474, 193.9352, 193.8626, 195.0434,
         194.0027, 194.5630, 193.9008],
        [193.9403, 193.6857, 194.5860, 193.6461, 193.6348, 193.5613, 194.7434,
         193.7027, 194.2631, 193.6004],
        [194.1241, 193.8699, 194.7697, 193.8308, 193.8190, 193.7457, 194.9276,
         193.8863, 194.4468, 193.7848],
        [194.0756, 193.8204, 194.7199, 193.7808, 193.7693, 193.6955, 194.8781,
         193.8355, 194.3983, 193.7354],
        [194.0021, 193.7470, 194.6467, 193.7080, 193.6964, 193.6232, 194.8050,
         193.7637, 194.3243, 193.6634],
        [194.2981, 194

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.8124, 194.2892, 193.8253, 193.4676, 194.1335, 194.6145, 194.1508,
         193.9589, 193.9627, 193.7720],
        [193.7335, 194.2124, 193.7468, 193.3880, 194.0555, 194.5347, 194.0735,
         193.8815, 193.8813, 193.6932],
        [193.5819, 194.0595, 193.5953, 193.2370, 193.9038, 194.3840, 193.9213,
         193.7284, 193.7323, 193.5414],
        [193.5624, 194.0411, 193.5758, 193.2169, 193.8846, 194.3640, 193.9028,
         193.7096, 193.7112, 193.5219],
        [193.6467, 194.1258, 193.6603, 193.3022, 193.9695, 194.4470, 193.9873,
         193.7959, 193.7938, 193.6069],
        [193.6204, 194.0971, 193.6331, 193.2751, 193.9417, 194.4227, 193.9592,
         193.7661, 193.7711, 193.5795],
        [193.9687, 194.4464, 193.9816, 193.6240, 194.2899, 194.7697, 194.3073,
         194.1169, 194.1170, 193.9288],
        [193.6375, 194

       grad_fn=<CdistBackward0>)
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[194.1926, 193.9453, 194.4012, 194.1438, 194.9875, 193.9302, 194.7562,
         193.5078, 194.0749, 193.7729],
        [193.9465, 193.7009, 194.1579, 193.8998, 194.7413, 193.6846, 194.5108,
         193.2619, 193.8289, 193.5276],
        [193.7458, 193.4993, 193.9567, 193.6998, 194.5411, 193.4844, 194.3101,
         193.0604, 193.6288, 193.3256],
        [194.0681, 193.8212, 194.2771, 194.0195, 194.8634, 193.8065, 194.6323,
         193.3832, 193.9512, 193.6490],
        [193.9458, 193.6995, 194.1571, 193.8998, 194.7405, 193.6833, 194.5097,
         193.2610, 193.8275, 193.5249],
        [193.9749, 193.7285, 194.1858, 193.9285, 194.7695, 193.7123, 194.5386,
         193.2903, 193.8567, 193.5538],
        [194.1063, 193.8608, 194.3184, 194.0608, 194.9019, 193.8448, 194.6705,
         193.4227, 193.9880, 193.6851],
        [193.9430, 193

tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.2342, 193.8020, 193.9730, 193.6180, 193.4076, 193.8988, 193.5911,
         193.0587, 193.5331, 194.7139],
        [193.3023, 193.8693, 194.0409, 193.6859, 193.4749, 193.9675, 193.6559,
         193.1253, 193.6002, 194.7827],
        [193.3126, 193.8783, 194.0517, 193.6954, 193.4860, 193.9769, 193.6675,
         193.1365, 193.6102, 194.7917],
        [193.1748, 193.7426, 193.9140, 193.5584, 193.3482, 193.8397, 193.5317,
         192.9988, 193.4735, 194.6547],
        [193.3261, 193.8909, 194.0661, 193.7085, 193.5006, 193.9887, 193.6797,
         193.1519, 193.6233, 194.8047],
        [193.0849, 193.6524, 193.8258, 193.4681, 193.2587, 193.7495, 193.4393,
         192.9089, 193.3828, 194.5653],
        [193.3692, 193.9360, 194.1078, 193.7527, 193.5421, 194.0331, 193.7214,
         193.1933, 193.6669, 194.8495],
        [193.2894, 193.8563, 194.0285, 193.6728, 193.46

       grad_fn=<CdistBackward0>)
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.9107, 193.7510, 193.8989, 193.6532, 194.0263, 193.1554, 193.6516,
         193.1620, 193.9890, 193.5379],
        [193.8182, 193.6586, 193.8060, 193.5613, 193.9341, 193.0626, 193.5598,
         193.0696, 193.8968, 193.4449],
        [193.8426, 193.6857, 193.8316, 193.5862, 193.9613, 193.0855, 193.5841,
         193.0966, 193.9224, 193.4702],
        [194.1187, 193.9613, 194.1078, 193.8636, 194.2351, 193.3647, 193.8607,
         193.3718, 194.1982, 193.7475],
        [193.9904, 193.8320, 193.9794, 193.7328, 194.1073, 193.2346, 193.7309,
         193.2431, 194.0692, 193.6185],
        [193.8414, 193.6820, 193.8293, 193.5839, 193.9576, 193.0855, 193.5827,
         193.0932, 193.9200, 193.4683],
        [193.9560, 193.7977, 193.9443, 193.7011, 194.0719, 193.2017, 193.6988,
         193.2081, 194.0352, 193.5835],
        [193.6211, 193

       grad_fn=<CdistBackward0>)
tensor([9, 8, 9, 9, 9, 8, 8, 9, 9, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[194.1142, 194.1071, 193.7632, 193.9995, 193.9156, 193.8673, 194.9084,
         193.7525, 194.2446, 193.7287],
        [194.2063, 194.1989, 193.8551, 194.0917, 194.0090, 193.9600, 195.0013,
         193.8466, 194.3377, 193.8210],
        [193.8615, 193.8551, 193.5105, 193.7474, 193.6626, 193.6168, 194.6566,
         193.5005, 193.9932, 193.4759],
        [193.9216, 193.9157, 193.5713, 193.8071, 193.7233, 193.6758, 194.7173,
         193.5615, 194.0540, 193.5357],
        [193.7530, 193.7455, 193.4013, 193.6387, 193.5562, 193.5092, 194.5475,
         193.3930, 193.8832, 193.3671],
        [194.1629, 194.1555, 193.8120, 194.0483, 193.9654, 193.9164, 194.9575,
         193.8026, 194.2938, 193.7775],
        [193.8674, 193.8597, 193.5163, 193.7525, 193.6713, 193.6219, 194.6615,
         193.5077, 193.9967, 193.4813],
        [194.1979, 194

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.0379, 193.7556, 193.2381, 193.4869, 193.9750, 193.8938, 193.8965,
         193.7663, 194.0542, 193.0616],
        [192.9288, 193.6508, 193.1306, 193.3776, 193.8704, 193.7853, 193.7877,
         193.6601, 193.9447, 192.9550],
        [192.9500, 193.6696, 193.1508, 193.3992, 193.8889, 193.8061, 193.8092,
         193.6796, 193.9656, 192.9751],
        [193.0954, 193.8176, 193.2977, 193.5441, 194.0363, 193.9509, 193.9545,
         193.8268, 194.1105, 193.1233],
        [193.0283, 193.7489, 193.2292, 193.4772, 193.9677, 193.8830, 193.8876,
         193.7578, 194.0432, 193.0529],
        [193.2294, 193.9485, 193.4307, 193.6780, 194.1666, 194.0840, 194.0879,
         193.9587, 194.2447, 193.2557],
        [192.9761, 193.6988, 193.1783, 193.4249, 193.9182, 193.8325, 193.8351,
         193.7079, 193.9917, 193.0033],
        [193.1517, 193

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.7536, 194.0204, 193.5411, 192.9253, 193.4781, 192.9514, 193.4245,
         193.4985, 193.6372, 194.6056],
        [193.5517, 193.8210, 193.3419, 192.7259, 193.2781, 192.7503, 193.2246,
         193.2985, 193.4352, 194.4048],
        [193.6711, 193.9388, 193.4596, 192.8438, 193.3964, 192.8693, 193.3428,
         193.4167, 193.5546, 194.5234],
        [193.8359, 194.1033, 193.6261, 193.0110, 193.5638, 193.0349, 193.5099,
         193.5832, 193.7196, 194.6895],
        [193.9500, 194.2198, 193.7408, 193.1261, 193.6790, 193.1511, 193.6248,
         193.6977, 193.8332, 194.8037],
        [193.8050, 194.0723, 193.5929, 192.9772, 193.5297, 193.0047, 193.4769,
         193.5500, 193.6879, 194.6565],
        [193.8241, 194.0911, 193.6132, 192.9980, 193.5508, 193.0225, 193.4968,
         193.5705, 193.7078, 194.6772],
        [193.8736, 194.1429, 193.6634, 193.0483, 193.6011, 193.0739, 193.5471,
         193.6203, 193.7568, 194

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.9792, 193.7516, 193.8793, 193.8812, 193.8975, 194.3607, 194.6235,
         193.8474, 193.6797, 193.7059],
        [194.0553, 193.8275, 193.9550, 193.9579, 193.9754, 194.4367, 194.6992,
         193.9245, 193.7582, 193.7844],
        [193.8810, 193.6532, 193.7804, 193.7830, 193.7997, 194.2627, 194.5250,
         193.7494, 193.5809, 193.6068],
        [193.8770, 193.6489, 193.7778, 193.7804, 193.7980, 194.2578, 194.5210,
         193.7458, 193.5807, 193.6068],
        [194.1946, 193.9671, 194.0954, 194.0972, 194.1143, 194.5761, 194.8389,
         194.0636, 193.8977, 193.9244],
        [193.8478, 193.6200, 193.7477, 193.7500, 193.7664, 194.2293, 194.4919,
         193.7159, 193.5480, 193.5739],
        [193.9296, 193.7017, 193.8312, 193.8331, 193.8503, 194.3106, 194.5739,
         193.7979, 193.6327, 193.6593],
        [193.8567, 193

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.3036, 193.5752, 193.7206, 193.3595, 194.0451, 193.5041, 193.8163,
         193.7379, 194.5516, 193.5901],
        [193.2425, 193.5120, 193.6591, 193.2988, 193.9843, 193.4410, 193.7529,
         193.6739, 194.4915, 193.5285],
        [193.2677, 193.5391, 193.6848, 193.3241, 194.0083, 193.4677, 193.7800,
         193.7019, 194.5156, 193.5538],
        [193.2990, 193.5709, 193.7161, 193.3559, 194.0403, 193.4990, 193.8111,
         193.7333, 194.5471, 193.5843],
        [193.0142, 193.2843, 193.4309, 193.0691, 193.7554, 193.2126, 193.5257,
         193.4473, 194.2630, 193.3016],
        [193.4135, 193.6833, 193.8305, 193.4698, 194.1541, 193.6134, 193.9250,
         193.8455, 194.6618, 193.7003],
        [193.0616, 193.3324, 193.4781, 193.1167, 193.8038, 193.2603, 193.5732,
         193.4952, 194.3105, 193.3482],
        [193.2994, 193

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.5538, 193.0485, 193.7805, 193.2404, 193.7206, 193.6994, 193.4235,
         193.5862, 193.4777, 193.5690],
        [193.2036, 192.6999, 193.4321, 192.8914, 193.3706, 193.3492, 193.0721,
         193.2368, 193.1272, 193.2183],
        [193.4352, 192.9321, 193.6640, 193.1226, 193.6028, 193.5827, 193.3050,
         193.4689, 193.3604, 193.4506],
        [193.5388, 193.0340, 193.7657, 193.2246, 193.7046, 193.6844, 193.4080,
         193.5706, 193.4616, 193.5515],
        [193.6949, 193.1899, 193.9218, 193.3809, 193.8618, 193.8415, 193.5652,
         193.7271, 193.6192, 193.7095],
        [193.3788, 192.8741, 193.6061, 193.0653, 193.5449, 193.5239, 193.2473,
         193.4109, 193.3013, 193.3919],
        [193.4578, 192.9533, 193.6852, 193.1440, 193.6237, 193.6032, 193.3266,
         193.4898, 193.3804, 193.4704],
        [193.4921, 192

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.3462, 193.4057, 193.7205, 193.8331, 193.2581, 193.2913, 193.0292,
         193.8955, 193.6663, 193.8213],
        [193.4417, 193.5021, 193.8173, 193.9293, 193.3545, 193.3876, 193.1249,
         193.9924, 193.7628, 193.9178],
        [193.3991, 193.4546, 193.7686, 193.8833, 193.3071, 193.3430, 193.0820,
         193.9483, 193.7189, 193.8699],
        [193.0179, 193.0781, 193.3941, 193.5056, 192.9285, 192.9619, 192.6998,
         193.5695, 193.3417, 193.4947],
        [193.3620, 193.4231, 193.7386, 193.8501, 193.2753, 193.3078, 193.0451,
         193.9125, 193.6833, 193.8391],
        [193.3600, 193.4202, 193.7348, 193.8470, 193.2712, 193.3061, 193.0437,
         193.9131, 193.6835, 193.8353],
        [193.3639, 193.4233, 193.7383, 193.8508, 193.2757, 193.3091, 193.0469,
         193.9135, 193.6842, 193.8390],
        [193.3790, 193

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.4419, 193.4796, 192.9670, 193.9138, 194.1338, 192.7517, 194.0175,
         193.3913, 193.6460, 193.0153],
        [193.3900, 193.4262, 192.9133, 193.8625, 194.0816, 192.6995, 193.9641,
         193.3389, 193.5913, 192.9630],
        [193.5022, 193.5392, 193.0268, 193.9746, 194.1937, 192.8121, 194.0771,
         193.4514, 193.7057, 193.0753],
        [193.2143, 193.2516, 192.7378, 193.6850, 193.9071, 192.5232, 193.7897,
         193.1632, 193.4162, 192.7869],
        [193.4243, 193.4617, 192.9491, 193.8967, 194.1161, 192.7342, 193.9992,
         193.3737, 193.6272, 192.9982],
        [193.5228, 193.5608, 193.0478, 193.9914, 194.2154, 192.8306, 194.0992,
         193.4718, 193.7263, 193.0942],
        [193.5101, 193.5458, 193.0331, 193.9820, 194.2016, 192.8191, 194.0841,
         193.4588, 193.7111, 193.0821],
        [193.5360, 193

tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[194.0873, 192.7829, 193.1566, 193.5173, 193.7621, 192.9508, 192.7754,
         193.4371, 193.6250, 193.4542],
        [193.9287, 192.6260, 192.9980, 193.3595, 193.6053, 192.7920, 192.6152,
         193.2777, 193.4683, 193.2965],
        [194.1823, 192.8793, 193.2524, 193.6125, 193.8591, 193.0455, 192.8693,
         193.5296, 193.7220, 193.5494],
        [194.1491, 192.8476, 193.2206, 193.5812, 193.8274, 193.0135, 192.8368,
         193.4970, 193.6897, 193.5185],
        [194.1991, 192.8972, 193.2704, 193.6309, 193.8772, 193.0636, 192.8868,
         193.5471, 193.7392, 193.5678],
        [194.0906, 192.7898, 193.1637, 193.5246, 193.7688, 192.9561, 192.7804,
         193.4402, 193.6315, 193.4632],
        [194.1470, 192.8441, 193.2187, 193.5788, 193.8229, 193.0112, 192.8367,
         193.4961, 193.6864, 193.5173],
        [193.9794, 192.6781, 193.0515, 193.4131, 193.65

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.8089, 192.9295, 193.4683, 193.1569, 194.3815, 193.4420, 193.6078,
         193.7196, 193.5640, 193.3329],
        [192.8463, 192.9656, 193.5065, 193.1945, 194.4192, 193.4779, 193.6444,
         193.7574, 193.6006, 193.3703],
        [192.6411, 192.7626, 193.3023, 192.9902, 194.2162, 193.2757, 193.4402,
         193.5542, 193.3978, 193.1673],
        [192.8458, 192.9638, 193.5068, 193.1942, 194.4185, 193.4760, 193.6438,
         193.7575, 193.5995, 193.3693],
        [192.8285, 192.9505, 193.4894, 193.1777, 194.4031, 193.4636, 193.6287,
         193.7414, 193.5848, 193.3570],
        [192.7319, 192.8522, 193.3934, 193.0811, 194.3069, 193.3650, 193.5300,
         193.6449, 193.4874, 193.2582],
        [192.7557, 192.8765, 193.4165, 193.1045, 194.3290, 193.3895, 193.5566,
         193.6683, 193.5115, 193.2814],
        [192.8307, 192

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.6597, 193.1989, 193.7093, 193.5465, 193.5333, 192.7561, 193.4476,
         193.3478, 194.5008, 193.6853],
        [193.6298, 193.1672, 193.6752, 193.5099, 193.4967, 192.7241, 193.4131,
         193.3170, 194.4666, 193.6500],
        [193.7837, 193.3236, 193.8336, 193.6697, 193.6571, 192.8801, 193.5711,
         193.4718, 194.6246, 193.8091],
        [193.4913, 193.0275, 193.5363, 193.3712, 193.3574, 192.5852, 193.2747,
         193.1782, 194.3280, 193.5115],
        [193.8625, 193.4015, 193.9106, 193.7453, 193.7325, 192.9582, 193.6491,
         193.5509, 194.7025, 193.8857],
        [193.6470, 193.1855, 193.6961, 193.5321, 193.5189, 192.7428, 193.4344,
         193.3349, 194.4878, 193.6719],
        [193.5689, 193.1071, 193.6162, 193.4522, 193.4390, 192.6640, 193.3535,
         193.2562, 194.4070, 193.5916],
        [193.8699, 193

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.2162, 193.2391, 193.3786, 193.1939, 193.2898, 192.9198, 193.2411,
         193.2730, 194.3158, 193.2696],
        [193.3114, 193.3365, 193.4717, 193.2913, 193.3847, 193.0165, 193.3391,
         193.3678, 194.4128, 193.3648],
        [193.1203, 193.1442, 193.2811, 193.1002, 193.1938, 192.8246, 193.1472,
         193.1788, 194.2223, 193.1743],
        [192.9227, 192.9470, 193.0851, 192.9016, 192.9966, 192.6259, 192.9485,
         192.9814, 194.0242, 192.9767],
        [193.2664, 193.2894, 193.4279, 193.2451, 193.3397, 192.9705, 193.2921,
         193.3242, 194.3671, 193.3200],
        [193.1416, 193.1646, 193.3040, 193.1197, 193.2149, 192.8449, 193.1664,
         193.2000, 194.2422, 193.1950],
        [193.4749, 193.4983, 193.6355, 193.4537, 193.5480, 193.1797, 193.5014,
         193.5306, 194.5749, 193.5280],
        [193.2754, 193

tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.5783, 192.9959, 193.2456, 194.2854, 192.7700, 193.2173, 193.4067,
         193.6252, 193.3531, 193.4883],
        [193.5975, 193.0152, 193.2649, 194.3043, 192.7891, 193.2362, 193.4256,
         193.6442, 193.3725, 193.5074],
        [193.2097, 192.6289, 192.8772, 193.9179, 192.4010, 192.8496, 193.0409,
         193.2584, 192.9833, 193.1227],
        [193.6676, 193.0868, 193.3362, 194.3764, 192.8610, 193.3089, 193.4992,
         193.7177, 193.4436, 193.5809],
        [193.5155, 192.9337, 193.1819, 194.2235, 192.7074, 193.1553, 193.3457,
         193.5637, 193.2900, 193.4280],
        [193.5721, 192.9915, 193.2423, 194.2800, 192.7649, 193.2125, 193.4030,
         193.6216, 193.3479, 193.4839],
        [193.7247, 193.1435, 193.3934, 194.4324, 192.9174, 193.3649, 193.5546,
         193.7734, 193.5009, 193.6364],
        [193.4536, 192.8726, 193.1235, 194.1614, 192.64

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.5480, 194.3879, 192.8624, 193.5383, 194.0279, 193.5688, 193.5742,
         192.7403, 193.4416, 193.4982],
        [193.4285, 194.2652, 192.7384, 193.4166, 193.9038, 193.4460, 193.4514,
         192.6195, 193.3206, 193.3746],
        [193.5177, 194.3544, 192.8273, 193.5056, 193.9927, 193.5348, 193.5401,
         192.7088, 193.4098, 193.4636],
        [193.4288, 194.2668, 192.7381, 193.4179, 193.9055, 193.4476, 193.4518,
         192.6182, 193.3206, 193.3766],
        [193.4397, 194.2765, 192.7497, 193.4271, 193.9152, 193.4571, 193.4626,
         192.6304, 193.3318, 193.3860],
        [193.7018, 194.5414, 193.0165, 193.6919, 194.1806, 193.7216, 193.7274,
         192.8955, 193.5957, 193.6509],
        [193.4445, 194.2843, 192.7567, 193.4338, 193.9242, 193.4652, 193.4698,
         192.6342, 193.3370, 193.3950],
        [193.4796, 194

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.6322, 192.7629, 193.6197, 193.4429, 193.5595, 193.3503, 193.2336,
         193.0882, 193.5047, 193.3953],
        [193.4826, 192.6138, 193.4704, 193.2930, 193.4110, 193.2004, 193.0834,
         192.9383, 193.3546, 193.2453],
        [193.6763, 192.8081, 193.6646, 193.4876, 193.6050, 193.3943, 193.2781,
         193.1324, 193.5491, 193.4403],
        [193.7531, 192.8857, 193.7415, 193.5631, 193.6825, 193.4718, 193.3551,
         193.2095, 193.6261, 193.5161],
        [193.4317, 192.5606, 193.4186, 193.2397, 193.3598, 193.1488, 193.0312,
         192.8861, 193.3025, 193.1927],
        [193.3506, 192.4786, 193.3369, 193.1590, 193.2774, 193.0674, 192.9496,
         192.8049, 193.2212, 193.1115],
        [193.5041, 192.6316, 193.4901, 193.3129, 193.4303, 193.2208, 193.1035,
         192.9583, 193.3750, 193.2659],
        [193.5223, 192

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.4492, 194.4080, 193.4709, 194.0071, 193.1642, 193.3665, 193.2881,
         193.6959, 194.2555, 193.4262],
        [193.2297, 194.1870, 193.2496, 193.7872, 192.9434, 193.1446, 193.0673,
         193.4746, 194.0350, 193.2063],
        [193.1332, 194.0909, 193.1510, 193.6874, 192.8461, 193.0490, 192.9707,
         193.3768, 193.9371, 193.1085],
        [193.2523, 194.2101, 193.2706, 193.8074, 192.9656, 193.1678, 193.0902,
         193.4963, 194.0567, 193.2281],
        [193.2456, 194.2029, 193.2653, 193.8031, 192.9593, 193.1604, 193.0833,
         193.4905, 194.0509, 193.2221],
        [193.2849, 194.2436, 193.3066, 193.8450, 193.0000, 193.2002, 193.1241,
         193.5323, 194.0919, 193.2627],
        [192.9357, 193.8929, 192.9551, 193.4917, 192.6487, 192.8517, 192.7721,
         193.1799, 193.7398, 192.9113],
        [193.4000, 194

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.5276, 193.2807, 193.7916, 193.1296, 193.4292, 192.9692, 193.1645,
         193.3228, 193.3041, 193.1461],
        [192.4274, 193.1774, 193.6899, 193.0282, 193.3303, 192.8685, 193.0629,
         193.2222, 193.2031, 193.0445],
        [192.5165, 193.2669, 193.7769, 193.1140, 193.4189, 192.9576, 193.1507,
         193.3108, 193.2894, 193.1346],
        [192.4135, 193.1669, 193.6770, 193.0138, 193.3150, 192.8543, 193.0497,
         193.2080, 193.1885, 193.0323],
        [192.5410, 193.2951, 193.8044, 193.1418, 193.4420, 192.9826, 193.1775,
         193.3360, 193.3163, 193.1600],
        [192.5091, 193.2575, 193.7687, 193.1075, 193.4122, 192.9510, 193.1428,
         193.3039, 193.2827, 193.1252],
        [192.4517, 193.2011, 193.7126, 193.0513, 193.3546, 192.8933, 193.0862,
         193.2466, 193.2262, 193.0682],
        [192.6185, 193

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.1588, 193.2645, 193.3237, 192.4750, 193.3773, 193.0545, 193.3031,
         193.4518, 192.9014, 193.0223],
        [193.1808, 193.2886, 193.3475, 192.4993, 193.3995, 193.0779, 193.3266,
         193.4761, 192.9253, 193.0461],
        [193.4256, 193.5337, 193.5934, 192.7445, 193.6432, 193.3234, 193.5717,
         193.7212, 193.1737, 193.2925],
        [193.0147, 193.1208, 193.1804, 192.3326, 193.2339, 192.9117, 193.1608,
         193.3094, 192.7583, 192.8789],
        [193.2571, 193.3612, 193.4218, 192.5735, 193.4746, 193.1536, 193.4026,
         193.5505, 193.0019, 193.1211],
        [193.0992, 193.2062, 193.2658, 192.4175, 193.3178, 192.9966, 193.2453,
         193.3942, 192.8446, 192.9644],
        [193.0299, 193.1360, 193.1955, 192.3473, 193.2488, 192.9266, 193.1752,
         193.3238, 192.7737, 192.8939],
        [193.2841, 193

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.1434, 193.9521, 192.9021, 192.8045, 192.9884, 192.8484, 193.0874,
         192.9174, 192.8187, 192.8506],
        [193.2938, 194.1025, 193.0525, 192.9554, 193.1390, 192.9994, 193.2377,
         193.0692, 192.9703, 193.0016],
        [193.3226, 194.1314, 193.0813, 192.9843, 193.1682, 193.0285, 193.2668,
         193.0981, 192.9995, 193.0305],
        [193.6248, 194.4337, 193.3837, 193.2863, 193.4698, 193.3303, 193.5682,
         193.3999, 193.3010, 193.3324],
        [193.2782, 194.0864, 193.0365, 192.9390, 193.1222, 192.9832, 193.2217,
         193.0531, 192.9535, 192.9856],
        [193.3254, 194.1343, 193.0843, 192.9875, 193.1710, 193.0310, 193.2693,
         193.1006, 193.0019, 193.0331],
        [193.4426, 194.2515, 193.2018, 193.1032, 193.2880, 193.1479, 193.3866,
         193.2163, 193.1186, 193.1499],
        [193.2962, 194

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.4977, 193.1571, 193.0506, 193.1525, 193.3819, 193.1702, 193.3679,
         193.4611, 192.7975, 194.1148],
        [193.4284, 193.0834, 192.9795, 193.0818, 193.3113, 193.1008, 193.2981,
         193.3908, 192.7249, 194.0443],
        [193.4380, 193.0972, 192.9912, 193.0934, 193.3234, 193.1113, 193.3083,
         193.4018, 192.7382, 194.0552],
        [193.2501, 192.9067, 192.8021, 192.9040, 193.1333, 192.9218, 193.1192,
         193.2125, 192.5477, 193.8665],
        [193.4168, 193.0774, 192.9699, 193.0719, 193.3022, 193.0895, 193.2872,
         193.3807, 192.7177, 194.0339],
        [193.3272, 192.9850, 192.8795, 192.9816, 193.2117, 192.9998, 193.1971,
         193.2904, 192.6260, 193.9438],
        [193.3629, 193.0183, 192.9144, 193.0163, 193.2446, 193.0341, 193.2318,
         193.3246, 192.6593, 193.9792],
        [193.3160, 192

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.2739, 193.8032, 193.4278, 193.3769, 193.2745, 193.3166, 192.5222,
         193.0162, 193.4474, 192.6893],
        [193.1637, 193.6934, 193.3188, 193.2667, 193.1649, 193.2062, 192.4115,
         192.9069, 193.3376, 192.5790],
        [193.2267, 193.7580, 193.3825, 193.3302, 193.2301, 193.2695, 192.4744,
         192.9715, 193.4010, 192.6421],
        [193.3984, 193.9288, 193.5509, 193.5016, 193.3997, 193.4418, 192.6465,
         193.1397, 193.5716, 192.8132],
        [193.2384, 193.7671, 193.3925, 193.3412, 193.2383, 193.2811, 192.4866,
         192.9806, 193.4119, 192.6535],
        [193.1521, 193.6836, 193.3073, 193.2559, 193.1554, 193.1969, 192.3991,
         192.8959, 193.3267, 192.5655],
        [193.2045, 193.7344, 193.3586, 193.3076, 193.2057, 193.2480, 192.4522,
         192.9468, 193.3783, 192.6190],
        [193.4817, 194

tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.4658, 193.8925, 193.2184, 192.9339, 193.5354, 193.1316, 194.0931,
         193.3879, 193.3480, 193.3090],
        [193.2309, 193.6572, 192.9858, 192.6997, 193.3014, 192.8975, 193.8589,
         193.1528, 193.1127, 193.0751],
        [193.5146, 193.9411, 193.2676, 192.9830, 193.5846, 193.1810, 194.1420,
         193.4367, 193.3965, 193.3584],
        [193.3094, 193.7383, 193.0642, 192.7783, 193.3796, 192.9769, 193.9389,
         193.2332, 193.1925, 193.1533],
        [193.3421, 193.7681, 193.0958, 192.8105, 193.4131, 193.0085, 193.9696,
         193.2639, 193.2234, 193.1870],
        [193.1727, 193.5999, 192.9272, 192.6413, 193.2437, 192.8394, 193.8011,
         193.0955, 193.0548, 193.0173],
        [193.4575, 193.8849, 193.2123, 192.9267, 193.5269, 193.1246, 194.0862,
         193.3797, 193.3400, 193.3007],
        [193.3671, 193.7950, 193.1198, 192.8351, 193.43

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.8819, 192.8334, 193.3181, 193.1528, 193.0069, 192.6127, 192.5321,
         192.9049, 193.4726, 192.9237],
        [192.9527, 192.9037, 193.3913, 193.2231, 193.0768, 192.6833, 192.6020,
         192.9778, 193.5466, 192.9951],
        [192.8850, 192.8373, 193.3233, 193.1559, 193.0100, 192.6156, 192.5339,
         192.9089, 193.4791, 192.9259],
        [192.8640, 192.8156, 193.3010, 193.1349, 192.9896, 192.5948, 192.5138,
         192.8878, 193.4555, 192.9062],
        [192.8385, 192.7896, 193.2768, 193.1092, 192.9624, 192.5699, 192.4883,
         192.8634, 193.4315, 192.8810],
        [192.9764, 192.9272, 193.4138, 193.2469, 193.1018, 192.7069, 192.6261,
         193.0009, 193.5684, 193.0196],
        [192.8987, 192.8499, 193.3372, 193.1693, 193.0228, 192.6296, 192.5482,
         192.9235, 193.4924, 192.9408],
        [192.9771, 192

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.3926, 192.7367, 193.4035, 193.0290, 193.1294, 193.2357, 193.0420,
         192.9186, 192.9565, 193.3323],
        [193.3133, 192.6572, 193.3250, 192.9509, 193.0517, 193.1571, 192.9639,
         192.8388, 192.8773, 193.2540],
        [193.3479, 192.6923, 193.3596, 192.9842, 193.0854, 193.1913, 192.9967,
         192.8739, 192.9130, 193.2870],
        [193.3139, 192.6579, 193.3244, 192.9523, 193.0513, 193.1586, 192.9632,
         192.8392, 192.8782, 193.2541],
        [193.1756, 192.5203, 193.1866, 192.8134, 192.9122, 193.0196, 192.8251,
         192.7006, 192.7396, 193.1154],
        [193.3455, 192.6894, 193.3566, 192.9828, 193.0829, 193.1891, 192.9951,
         192.8711, 192.9094, 193.2856],
        [193.1408, 192.4859, 193.1546, 192.7772, 192.8791, 192.9837, 192.7902,
         192.6660, 192.7063, 193.0797],
        [193.2326, 192

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.9208, 192.9470, 193.5723, 192.9569, 193.0974, 192.8208, 192.9549,
         192.7729, 192.5902, 192.7605],
        [192.8983, 192.9214, 193.5466, 192.9350, 193.0739, 192.7969, 192.9322,
         192.7474, 192.5674, 192.7362],
        [192.9039, 192.9300, 193.5562, 192.9407, 193.0809, 192.8044, 192.9407,
         192.7557, 192.5745, 192.7449],
        [192.8618, 192.8887, 193.5149, 192.8987, 193.0391, 192.7625, 192.8988,
         192.7146, 192.5324, 192.7034],
        [192.9334, 192.9563, 193.5802, 192.9690, 193.1085, 192.8315, 192.9637,
         192.7823, 192.6011, 192.7691],
        [192.8536, 192.8758, 193.5002, 192.8902, 193.0289, 192.7515, 192.8856,
         192.7019, 192.5220, 192.6899],
        [192.8896, 192.9158, 193.5417, 192.9263, 193.0665, 192.7898, 192.9257,
         192.7415, 192.5599, 192.7304],
        [192.9104, 192

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.3260, 192.7761, 192.9508, 192.6883, 193.7547, 193.1684, 192.8102,
         192.3156, 193.6460, 192.3407],
        [193.2903, 192.7433, 192.9182, 192.6573, 193.7214, 193.1359, 192.7775,
         192.2818, 193.6134, 192.3105],
        [193.5153, 192.9678, 193.1409, 192.8802, 193.9451, 193.3589, 193.0015,
         192.5058, 193.8365, 192.5340],
        [193.3074, 192.7592, 192.9349, 192.6747, 193.7385, 193.1529, 192.7941,
         192.2988, 193.6312, 192.3275],
        [193.3808, 192.8323, 193.0073, 192.7465, 193.8110, 193.2253, 192.8668,
         192.3717, 193.7032, 192.3994],
        [193.2438, 192.6946, 192.8707, 192.6102, 193.6745, 193.0884, 192.7299,
         192.2347, 193.5676, 192.2627],
        [193.2935, 192.7449, 192.9187, 192.6557, 193.7224, 193.1362, 192.7783,
         192.2831, 193.6128, 192.3083],
        [193.3061, 192

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.2619, 192.7723, 192.5851, 193.1072, 193.2428, 193.0664, 193.0872,
         193.0870, 193.1281, 193.1664],
        [192.3935, 192.9030, 192.7160, 193.2374, 193.3741, 193.1989, 193.2188,
         193.2174, 193.2581, 193.2982],
        [192.2980, 192.8097, 192.6236, 193.1460, 193.2838, 193.1053, 193.1266,
         193.1263, 193.1682, 193.2046],
        [192.5490, 193.0586, 192.8716, 193.3932, 193.5306, 193.3523, 193.3740,
         193.3730, 193.4137, 193.4518],
        [192.3123, 192.8241, 192.6372, 193.1602, 193.2974, 193.1188, 193.1398,
         193.1400, 193.1817, 193.2182],
        [192.3309, 192.8418, 192.6555, 193.1776, 193.3143, 193.1348, 193.1579,
         193.1575, 193.1993, 193.2349],
        [192.2343, 192.7446, 192.5572, 193.0793, 193.2151, 193.0414, 193.0598,
         193.0594, 193.1000, 193.1407],
        [192.4859, 192

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.2329, 193.0627, 193.0578, 192.2399, 193.0661, 192.3991, 193.0988,
         193.0629, 192.2048, 192.9100],
        [192.1960, 193.0291, 193.0249, 192.2073, 193.0340, 192.3643, 193.0669,
         193.0290, 192.1717, 192.8756],
        [192.4569, 193.2863, 193.2810, 192.4638, 193.2907, 192.6237, 193.3223,
         193.2860, 192.4297, 193.1343],
        [192.0122, 192.8445, 192.8409, 192.0227, 192.8496, 192.1794, 192.8828,
         192.8454, 191.9873, 192.6900],
        [192.0277, 192.8595, 192.8556, 192.0373, 192.8640, 192.1947, 192.8972,
         192.8602, 192.0019, 192.7055],
        [192.1122, 192.9447, 192.9410, 192.1231, 192.9499, 192.2796, 192.9828,
         192.9453, 192.0876, 192.7903],
        [192.3820, 193.2122, 193.2070, 192.3897, 193.2166, 192.5491, 193.2484,
         193.2120, 192.3553, 193.0599],
        [192.2002, 193

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.8420, 192.9725, 193.0106, 192.4386, 192.8098, 192.8465, 192.9492,
         192.8410, 193.0720, 193.1492],
        [192.5275, 192.6608, 192.6982, 192.1241, 192.4984, 192.5331, 192.6356,
         192.5278, 192.7593, 192.8367],
        [192.5900, 192.7258, 192.7586, 192.1852, 192.5632, 192.5925, 192.6966,
         192.5900, 192.8218, 192.8989],
        [192.7299, 192.8644, 192.8988, 192.3265, 192.7016, 192.7341, 192.8372,
         192.7297, 192.9604, 193.0384],
        [192.8086, 192.9434, 192.9762, 192.4045, 192.7804, 192.8114, 192.9149,
         192.8081, 193.0390, 193.1164],
        [192.7728, 192.9050, 192.9405, 192.3688, 192.7437, 192.7737, 192.8799,
         192.7719, 193.0050, 193.0810],
        [192.6185, 192.7520, 192.7875, 192.2138, 192.5887, 192.6230, 192.7249,
         192.6183, 192.8494, 192.9264],
        [192.5946, 192

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.7498, 193.0580, 192.7815, 193.0618, 192.5502, 192.8810, 193.0572,
         193.1514, 192.1345, 193.7880],
        [192.6418, 192.9479, 192.6714, 192.9524, 192.4408, 192.7702, 192.9478,
         193.0409, 192.0243, 193.6786],
        [192.5609, 192.8647, 192.5880, 192.8693, 192.3583, 192.6877, 192.8649,
         192.9580, 191.9416, 193.5966],
        [192.4687, 192.7755, 192.4980, 192.7793, 192.2681, 192.5997, 192.7748,
         192.8693, 191.8527, 193.5068],
        [192.7103, 193.0167, 192.7406, 193.0221, 192.5099, 192.8397, 193.0163,
         193.1098, 192.0938, 193.7477],
        [192.5092, 192.8129, 192.5359, 192.8172, 192.3065, 192.6368, 192.8129,
         192.9064, 191.8902, 193.5453],
        [192.6386, 192.9443, 192.6679, 192.9492, 192.4373, 192.7660, 192.9444,
         193.0371, 192.0205, 193.6748],
        [192.7428, 193.0486, 192.7723, 193.0524, 192.54

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.7255, 192.4865, 191.7815, 192.7139, 192.2290, 193.3703, 191.9738,
         192.8673, 192.2751, 192.7323],
        [192.9645, 192.7247, 192.0206, 192.9519, 192.4679, 193.6099, 192.2145,
         193.1061, 192.5151, 192.9712],
        [193.0643, 192.8256, 192.1221, 193.0501, 192.5692, 193.7092, 192.3151,
         193.2042, 192.6164, 193.0707],
        [193.0562, 192.8159, 192.1131, 193.0435, 192.5601, 193.7014, 192.3055,
         193.1987, 192.6066, 193.0624],
        [193.1235, 192.8845, 192.1794, 193.1112, 192.6262, 193.7687, 192.3727,
         193.2659, 192.6740, 193.1307],
        [193.0545, 192.8160, 192.1123, 193.0415, 192.5590, 193.6991, 192.3035,
         193.1961, 192.6056, 193.0611],
        [192.9141, 192.6761, 191.9710, 192.9021, 192.4180, 193.5587, 192.1625,
         193.0557, 192.4646, 192.9211],
        [192.7818, 192

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.6734, 193.3156, 192.7353, 191.9448, 192.5775, 192.7273, 192.8920,
         192.8671, 193.3384, 192.3846],
        [193.6849, 193.3271, 192.7466, 191.9564, 192.5891, 192.7390, 192.9034,
         192.8783, 193.3492, 192.3960],
        [193.8729, 193.5150, 192.9336, 192.1428, 192.7764, 192.9250, 193.0907,
         193.0656, 193.5375, 192.5826],
        [193.7827, 193.4256, 192.8428, 192.0520, 192.6864, 192.8356, 193.0005,
         192.9757, 193.4467, 192.4914],
        [193.7961, 193.4387, 192.8568, 192.0666, 192.7001, 192.8495, 193.0139,
         192.9890, 193.4596, 192.5059],
        [193.6237, 193.2652, 192.6854, 191.8951, 192.5272, 192.6754, 192.8414,
         192.8170, 193.2898, 192.3342],
        [193.9193, 193.5620, 192.9801, 192.1890, 192.8232, 192.9730, 193.1377,
         193.1123, 193.5832, 192.6293],
        [193.6961, 193

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.1980, 192.5726, 192.9135, 191.9890, 192.2469, 193.5578, 192.6190,
         192.6667, 193.2414, 192.4997],
        [192.2759, 192.6513, 192.9919, 192.0676, 192.3246, 193.6360, 192.6975,
         192.7450, 193.3199, 192.5779],
        [191.9565, 192.3329, 192.6743, 191.7490, 192.0056, 193.3187, 192.3793,
         192.4270, 193.0027, 192.2598],
        [192.2185, 192.5909, 192.9333, 192.0070, 192.2654, 193.5777, 192.6394,
         192.6859, 193.2630, 192.5188],
        [192.0418, 192.4173, 192.7601, 191.8339, 192.0904, 193.4043, 192.4648,
         192.5123, 193.0884, 192.3453],
        [192.3492, 192.7191, 193.0633, 192.1366, 192.3972, 193.7078, 192.7681,
         192.8161, 193.3908, 192.6497],
        [192.1397, 192.5153, 192.8559, 191.9314, 192.1885, 193.5002, 192.5616,
         192.6090, 193.1843, 192.4418],
        [192.0275, 192

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.4911, 193.2472, 192.6674, 191.8458, 192.5291, 192.5750, 192.4250,
         192.5890, 192.5353, 192.4301],
        [192.7661, 193.5210, 192.9404, 192.1195, 192.8005, 192.8479, 192.6982,
         192.8623, 192.8101, 192.7041],
        [192.6201, 193.3784, 192.7988, 191.9782, 192.6595, 192.7069, 192.5559,
         192.7195, 192.6657, 192.5622],
        [192.4388, 193.1945, 192.6145, 191.7939, 192.4748, 192.5224, 192.3721,
         192.5356, 192.4829, 192.3784],
        [192.5559, 193.3098, 192.7295, 191.9084, 192.5894, 192.6370, 192.4871,
         192.6515, 192.5994, 192.4931],
        [192.6686, 193.4260, 192.8457, 192.0256, 192.7059, 192.7538, 192.6035,
         192.7664, 192.7137, 192.6101],
        [192.5228, 193.2791, 192.6996, 191.8784, 192.5602, 192.6074, 192.4566,
         192.6210, 192.5675, 192.4625],
        [192.7259, 193

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.4945, 192.2633, 192.9498, 192.8602, 192.0877, 192.6721, 193.6454,
         192.5962, 192.9141, 192.7630],
        [193.3314, 192.0988, 192.7885, 192.6978, 191.9227, 192.5093, 193.4820,
         192.4323, 192.7521, 192.5991],
        [193.5352, 192.3034, 192.9911, 192.9012, 192.1275, 192.7128, 193.6855,
         192.6364, 192.9555, 192.8029],
        [193.4813, 192.2496, 192.9378, 192.8482, 192.0729, 192.6587, 193.6307,
         192.5835, 192.9014, 192.7508],
        [193.3855, 192.1534, 192.8415, 192.7516, 191.9785, 192.5630, 193.5363,
         192.4876, 192.8060, 192.6543],
        [193.2527, 192.0203, 192.7100, 192.6196, 191.8437, 192.4304, 193.4030,
         192.3544, 192.6733, 192.5217],
        [193.4087, 192.1766, 192.8660, 192.7755, 191.9988, 192.5866, 193.5585,
         192.5094, 192.8287, 192.6768],
        [193.5166, 192.2856, 192.9716, 192.8821, 192.11

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.6844, 192.2926, 193.4239, 193.5048, 192.7495, 192.7709, 192.4680,
         192.5005, 192.9281, 192.5751],
        [192.7178, 192.3291, 193.4564, 193.5405, 192.7861, 192.8054, 192.5009,
         192.5338, 192.9593, 192.6070],
        [192.8687, 192.4769, 193.6073, 193.6894, 192.9313, 192.9559, 192.6521,
         192.6842, 193.1100, 192.7597],
        [192.7048, 192.3129, 193.4446, 193.5254, 192.7703, 192.7921, 192.4889,
         192.5208, 192.9485, 192.5949],
        [192.7525, 192.3604, 193.4913, 193.5728, 192.8156, 192.8389, 192.5354,
         192.5681, 192.9947, 192.6439],
        [192.6537, 192.2648, 193.3927, 193.4763, 192.7224, 192.7410, 192.4367,
         192.4697, 192.8958, 192.5429],
        [192.8302, 192.4384, 193.5692, 193.6511, 192.8936, 192.9178, 192.6139,
         192.6458, 193.0718, 192.7207],
        [192.5690, 192

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.8961, 191.8029, 192.7267, 192.3129, 192.4343, 192.4142, 191.9022,
         192.2172, 192.5226, 192.5840],
        [191.9514, 191.8587, 192.7839, 192.3695, 192.4912, 192.4714, 191.9596,
         192.2734, 192.5798, 192.6403],
        [191.6913, 191.6010, 192.5250, 192.1107, 192.2319, 192.2116, 191.7000,
         192.0144, 192.3186, 192.3810],
        [191.9882, 191.8949, 192.8179, 192.4042, 192.5254, 192.5058, 191.9928,
         192.3084, 192.6123, 192.6739],
        [191.9658, 191.8735, 192.7984, 192.3845, 192.5051, 192.4861, 191.9729,
         192.2881, 192.5929, 192.6544],
        [192.0430, 191.9504, 192.8760, 192.4623, 192.5825, 192.5640, 192.0502,
         192.3657, 192.6709, 192.7323],
        [191.6759, 191.5863, 192.5107, 192.0960, 192.2172, 192.1971, 191.6855,
         191.9995, 192.3035, 192.3659],
        [191.9325, 191

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.4744, 192.7181, 193.4570, 192.8856, 192.9051, 192.1479, 192.4857,
         192.8150, 192.8104, 192.3052],
        [192.3709, 192.6176, 193.3569, 192.7824, 192.8020, 192.0454, 192.3836,
         192.7150, 192.7092, 192.2033],
        [192.5421, 192.7898, 193.5289, 192.9532, 192.9739, 192.2205, 192.5547,
         192.8864, 192.8819, 192.3765],
        [192.3814, 192.6270, 193.3668, 192.7935, 192.8125, 192.0568, 192.3926,
         192.7248, 192.7191, 192.2140],
        [192.4297, 192.6734, 193.4125, 192.8412, 192.8604, 192.1032, 192.4408,
         192.7705, 192.7656, 192.2605],
        [192.2405, 192.4865, 193.2255, 192.6521, 192.6713, 191.9138, 192.2529,
         192.5833, 192.5781, 192.0720],
        [192.2994, 192.5474, 193.2861, 192.7107, 192.7309, 191.9759, 192.3128,
         192.6437, 192.6391, 192.1328],
        [192.1560, 192

tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.2547, 192.9286, 192.5258, 192.6111, 192.6965, 192.5586, 192.3853,
         192.5622, 192.5013, 193.5084],
        [191.9693, 192.6451, 192.2416, 192.3270, 192.4146, 192.2752, 192.0991,
         192.2805, 192.2176, 193.2229],
        [192.0550, 192.7309, 192.3274, 192.4120, 192.5002, 192.3611, 192.1843,
         192.3657, 192.3029, 193.3081],
        [192.2340, 192.9090, 192.5049, 192.5907, 192.6779, 192.5384, 192.3635,
         192.5424, 192.4810, 193.4881],
        [192.1925, 192.8698, 192.4671, 192.5484, 192.6363, 192.4998, 192.3232,
         192.5004, 192.4384, 193.4462],
        [192.1451, 192.8220, 192.4190, 192.5015, 192.5890, 192.4520, 192.2758,
         192.4535, 192.3914, 193.3990],
        [192.3518, 193.0280, 192.6248, 192.7066, 192.7951, 192.6580, 192.4814,
         192.6588, 192.5973, 193.6045],
        [192.1177, 192.7950, 192.3922, 192.4740, 192.56

tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[193.3961, 192.4400, 191.7369, 191.6731, 192.4397, 192.2546, 192.5922,
         192.1444, 192.4531, 192.6269],
        [193.3803, 192.4234, 191.7213, 191.6570, 192.4239, 192.2381, 192.5761,
         192.1283, 192.4375, 192.6115],
        [193.4814, 192.5242, 191.8227, 191.7563, 192.5254, 192.3368, 192.6768,
         192.2297, 192.5390, 192.7136],
        [193.5051, 192.5493, 191.8453, 191.7800, 192.5477, 192.3617, 192.7006,
         192.2534, 192.5611, 192.7361],
        [193.7197, 192.7634, 192.0603, 191.9952, 192.7612, 192.5764, 192.9146,
         192.4683, 192.7741, 192.9489],
        [193.5639, 192.6077, 191.9046, 191.8408, 192.6055, 192.4223, 192.7593,
         192.3120, 192.6187, 192.7933],
        [193.4187, 192.4619, 191.7592, 191.6937, 192.4608, 192.2748, 192.6140,
         192.1659, 192.4747, 192.6503],
        [193.4834, 192.5277, 191.8229, 191.7584, 192.52

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.5826, 191.9307, 192.5054, 191.9547, 192.3461, 192.3401, 192.9145,
         191.7538, 192.2770, 192.6747],
        [192.5583, 191.9065, 192.4804, 191.9312, 192.3213, 192.3163, 192.8896,
         191.7306, 192.2552, 192.6508],
        [192.6459, 191.9950, 192.5664, 192.0190, 192.4081, 192.4037, 192.9762,
         191.8208, 192.3432, 192.7394],
        [192.7274, 192.0763, 192.6497, 192.0990, 192.4910, 192.4844, 193.0593,
         191.9003, 192.4201, 192.8201],
        [192.5896, 191.9384, 192.5103, 191.9623, 192.3519, 192.3470, 192.9201,
         191.7637, 192.2854, 192.6830],
        [192.4801, 191.8280, 192.4021, 191.8525, 192.2430, 192.2374, 192.8115,
         191.6517, 192.1744, 192.5728],
        [192.6571, 192.0067, 192.5797, 192.0294, 192.4211, 192.4157, 192.9877,
         191.8302, 192.3533, 192.7500],
        [192.6125, 191

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.9902, 192.4008, 192.3228, 192.8726, 192.3226, 192.4114, 192.0374,
         192.2338, 192.2161, 192.2396],
        [192.9560, 192.3662, 192.2879, 192.8390, 192.2868, 192.3773, 192.0022,
         192.1989, 192.1829, 192.2020],
        [192.9810, 192.3903, 192.3129, 192.8628, 192.3136, 192.4010, 192.0279,
         192.2243, 192.2056, 192.2309],
        [193.0760, 192.4856, 192.4083, 192.9579, 192.4081, 192.4963, 192.1225,
         192.3195, 192.3006, 192.3236],
        [193.2732, 192.6818, 192.6050, 193.1562, 192.6035, 192.6939, 192.3197,
         192.5164, 192.4999, 192.5209],
        [193.1777, 192.5857, 192.5085, 193.0609, 192.5073, 192.5979, 192.2235,
         192.4202, 192.4043, 192.4241],
        [193.0714, 192.4778, 192.4012, 192.9534, 192.4011, 192.4897, 192.1160,
         192.3132, 192.2956, 192.3160],
        [192.9331, 192

tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.4921, 192.0966, 192.2019, 192.3391, 193.2955, 193.2183, 192.1300,
         191.6639, 192.3166, 192.2558],
        [192.5412, 192.1458, 192.2508, 192.3887, 193.3441, 193.2695, 192.1798,
         191.7166, 192.3649, 192.3040],
        [192.5017, 192.1052, 192.2111, 192.3482, 193.3047, 193.2292, 192.1396,
         191.6760, 192.3245, 192.2646],
        [192.7926, 192.3974, 192.5029, 192.6394, 193.5964, 193.5195, 192.4315,
         191.9678, 192.6173, 192.5563],
        [192.4703, 192.0748, 192.1796, 192.3179, 193.2729, 193.1983, 192.1085,
         191.6449, 192.2938, 192.2328],
        [192.5451, 192.1497, 192.2549, 192.3915, 193.3492, 193.2717, 192.1833,
         191.7169, 192.3695, 192.3096],
        [192.5424, 192.1474, 192.2523, 192.3899, 193.3459, 193.2685, 192.1806,
         191.7145, 192.3675, 192.3060],
        [192.5220, 192.1266, 192.2316, 192.3698, 193.32

tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.8444, 191.9004, 191.6003, 191.9849, 192.0067, 191.9507, 191.3082,
         191.8174, 192.1496, 192.1890],
        [192.0583, 192.1149, 191.8141, 192.1993, 192.2205, 192.1660, 191.5219,
         192.0320, 192.3638, 192.4035],
        [192.2303, 192.2880, 191.9848, 192.3700, 192.3896, 192.3381, 191.6939,
         192.2023, 192.5349, 192.5772],
        [191.9830, 192.0401, 191.7385, 192.1232, 192.1439, 192.0900, 191.4468,
         191.9554, 192.2882, 192.3290],
        [192.3586, 192.4147, 192.1129, 192.4980, 192.5184, 192.4662, 191.8217,
         192.3306, 192.6633, 192.7042],
        [192.0030, 192.0584, 191.7583, 192.1434, 192.1650, 192.1098, 191.4664,
         191.9761, 192.3079, 192.3474],
        [192.1714, 192.2298, 191.9272, 192.3119, 192.3320, 192.2798, 191.6351,
         192.1442, 192.4773, 192.5186],
        [192.2082, 192.2631, 191.9616, 192.3467, 192.36

tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.5873, 192.3455, 192.5039, 192.4236, 191.9160, 192.5638, 192.4506,
         192.8246, 192.1129, 192.6622],
        [192.3367, 192.0939, 192.2516, 192.1718, 191.6649, 192.3141, 192.1986,
         192.5741, 191.8624, 192.4099],
        [192.2947, 192.0515, 192.2092, 192.1296, 191.6220, 192.2711, 192.1564,
         192.5315, 191.8203, 192.3676],
        [192.4471, 192.2049, 192.3638, 192.2835, 191.7764, 192.4258, 192.3104,
         192.6856, 191.9734, 192.5220],
        [192.3178, 192.0747, 192.2322, 192.1525, 191.6446, 192.2932, 192.1794,
         192.5540, 191.8431, 192.3909],
        [192.1534, 191.9108, 192.0681, 191.9865, 191.4783, 192.1286, 192.0144,
         192.3900, 191.6789, 192.2272],
        [192.2961, 192.0533, 192.2123, 192.1322, 191.6243, 192.2742, 192.1590,
         192.5341, 191.8225, 192.3708],
        [192.0103, 191.7675, 191.9275, 191.8452, 191.33

tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.6588, 192.4289, 192.5339, 192.0052, 192.5747, 192.4919, 192.3396,
         193.2426, 191.8129, 192.0163],
        [192.4127, 192.1820, 192.2884, 191.7580, 192.3275, 192.2432, 192.0942,
         192.9960, 191.5653, 191.7694],
        [192.5826, 192.3522, 192.4575, 191.9267, 192.4971, 192.4135, 192.2639,
         193.1656, 191.7353, 191.9391],
        [192.3198, 192.0894, 192.1934, 191.6620, 192.2330, 192.1485, 191.9999,
         192.9019, 191.4718, 191.6746],
        [192.4395, 192.2092, 192.3129, 191.7823, 192.3533, 192.2697, 192.1200,
         193.0219, 191.5919, 191.7947],
        [192.3924, 192.1620, 192.2668, 191.7346, 192.3058, 192.2212, 192.0737,
         192.9747, 191.5443, 191.7480],
        [192.3797, 192.1491, 192.2528, 191.7239, 192.2941, 192.2109, 192.0596,
         192.9623, 191.5323, 191.7345],
        [192.5445, 192.3146, 192.4204, 191.8891, 192.45

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.3600, 192.4933, 191.9845, 193.2711, 192.3776, 192.6159, 192.0458,
         193.4206, 192.4582, 192.6027],
        [192.0080, 192.1409, 191.6311, 192.9193, 192.0237, 192.2642, 191.6927,
         193.0701, 192.1059, 192.2510],
        [192.1194, 192.2525, 191.7432, 193.0305, 192.1366, 192.3746, 191.8044,
         193.1811, 192.2168, 192.3620],
        [192.1915, 192.3238, 191.8145, 193.1022, 192.2084, 192.4469, 191.8765,
         193.2523, 192.2890, 192.4336],
        [192.2138, 192.3457, 191.8366, 193.1243, 192.2306, 192.4688, 191.8987,
         193.2747, 192.3110, 192.4557],
        [192.1148, 192.2496, 191.7404, 193.0267, 192.1329, 192.3704, 191.8006,
         193.1776, 192.2125, 192.3586],
        [192.0261, 192.1629, 191.6528, 192.9392, 192.0446, 192.2845, 191.7121,
         193.0890, 192.1257, 192.2713],
        [192.1048, 192

tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.1893, 191.3655, 191.8036, 191.8759, 192.1850, 191.8944, 192.0084,
         192.0329, 191.5343, 191.8754],
        [192.1750, 191.3510, 191.7888, 191.8615, 192.1704, 191.8797, 191.9935,
         192.0182, 191.5195, 191.8608],
        [192.1206, 191.2966, 191.7341, 191.8069, 192.1159, 191.8252, 191.9381,
         191.9628, 191.4646, 191.8067],
        [192.5401, 191.7154, 192.1537, 192.2281, 192.5342, 192.2443, 192.3568,
         192.3809, 191.8854, 192.2265],
        [192.3620, 191.5383, 191.9749, 192.0492, 192.3570, 192.0671, 192.1785,
         192.2049, 191.7067, 192.0479],
        [192.3706, 191.5468, 191.9861, 192.0581, 192.3667, 192.0763, 192.1911,
         192.2149, 191.7171, 192.0574],
        [192.2273, 191.4029, 191.8416, 191.9141, 192.2224, 191.9318, 192.0456,
         192.0692, 191.5721, 191.9137],
        [192.3238, 191.4996, 191.9387, 192.0111, 192.31

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.4663, 192.1142, 192.7025, 192.4424, 192.3141, 192.6904, 192.6190,
         192.6140, 193.1954, 192.3120],
        [192.2344, 191.8828, 192.4710, 192.2101, 192.0824, 192.4593, 192.3877,
         192.3834, 192.9640, 192.0809],
        [191.8896, 191.5377, 192.1266, 191.8663, 191.7377, 192.1163, 192.0432,
         192.0385, 192.6189, 191.7340],
        [192.0326, 191.6796, 192.2681, 192.0074, 191.8815, 192.2567, 192.1841,
         192.1806, 192.7603, 191.8757],
        [192.2744, 191.9224, 192.5109, 192.2507, 192.1213, 192.4998, 192.4284,
         192.4219, 193.0041, 192.1201],
        [192.2199, 191.8683, 192.4569, 192.1968, 192.0691, 192.4455, 192.3726,
         192.3696, 192.9492, 192.0657],
        [192.1661, 191.8133, 192.4016, 192.1402, 192.0136, 192.3896, 192.3186,
         192.3135, 192.8946, 192.0107],
        [192.1675, 191

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.0922, 191.7195, 191.9864, 192.6053, 192.0770, 192.4720, 192.0412,
         191.7314, 192.3017, 192.2891],
        [191.9047, 191.5299, 191.7966, 192.4175, 191.8867, 192.2846, 191.8544,
         191.5430, 192.1156, 192.1019],
        [192.0407, 191.6684, 191.9341, 192.5553, 192.0255, 192.4220, 191.9916,
         191.6805, 192.2493, 192.2374],
        [191.7718, 191.3991, 191.6655, 192.2848, 191.7568, 192.1520, 191.7215,
         191.4098, 191.9821, 191.9695],
        [192.1012, 191.7261, 191.9931, 192.6132, 192.0824, 192.4813, 192.0511,
         191.7401, 192.3116, 192.2984],
        [192.0923, 191.7190, 191.9847, 192.6075, 192.0757, 192.4735, 192.0434,
         191.7323, 192.3013, 192.2886],
        [192.1369, 191.7635, 192.0294, 192.6518, 192.1203, 192.5181, 192.0879,
         191.7770, 192.3457, 192.3332],
        [191.9675, 191

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.8163, 191.4051, 192.2440, 192.7886, 192.0333, 192.1451, 192.3210,
         192.2661, 192.1009, 192.0034],
        [191.8330, 191.4213, 192.2623, 192.8062, 192.0500, 192.1630, 192.3383,
         192.2833, 192.1169, 192.0218],
        [191.8468, 191.4352, 192.2757, 192.8186, 192.0627, 192.1757, 192.3508,
         192.2948, 192.1313, 192.0340],
        [191.7170, 191.3071, 192.1457, 192.6914, 191.9336, 192.0440, 192.2221,
         192.1659, 192.0024, 191.9051],
        [191.8480, 191.4364, 192.2762, 192.8195, 192.0643, 192.1771, 192.3521,
         192.2966, 192.1326, 192.0347],
        [191.7323, 191.3219, 192.1613, 192.7071, 191.9495, 192.0605, 192.2380,
         192.1826, 192.0170, 191.9214],
        [191.7468, 191.3353, 192.1755, 192.7193, 191.9639, 192.0764, 192.2520,
         192.1970, 192.0309, 191.9346],
        [191.8721, 191

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.5506, 191.8771, 192.3855, 192.1989, 191.6222, 191.8902, 191.9119,
         192.6883, 192.4068, 193.1133],
        [192.1487, 191.4726, 191.9815, 191.7965, 191.2196, 191.4855, 191.5095,
         192.2856, 192.0033, 192.7103],
        [192.5357, 191.8612, 192.3698, 192.1839, 191.6068, 191.8747, 191.8965,
         192.6730, 192.3915, 193.0979],
        [192.3799, 191.7062, 192.2143, 192.0300, 191.4528, 191.7195, 191.7417,
         192.5167, 192.2372, 192.9428],
        [192.3224, 191.6466, 192.1564, 191.9692, 191.3917, 191.6591, 191.6829,
         192.4589, 192.1764, 192.8833],
        [192.2108, 191.5379, 192.0470, 191.8585, 191.2818, 191.5484, 191.5731,
         192.3474, 192.0659, 192.7733],
        [192.3805, 191.7041, 192.2131, 192.0281, 191.4506, 191.7180, 191.7406,
         192.5174, 192.2352, 192.9416],
        [192.3183, 191

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.6339, 191.8596, 191.2505, 191.5397, 191.9453, 191.9868, 192.1692,
         192.9652, 191.9285, 191.3414],
        [192.7729, 191.9976, 191.3885, 191.6786, 192.0831, 192.1275, 192.3089,
         193.1054, 192.0688, 191.4796],
        [192.7486, 191.9761, 191.3646, 191.6559, 192.0597, 192.1057, 192.2854,
         193.0815, 192.0447, 191.4565],
        [192.8082, 192.0328, 191.4235, 191.7136, 192.1184, 192.1628, 192.3438,
         193.1403, 192.1036, 191.5153],
        [192.6408, 191.8737, 191.2598, 191.5505, 191.9554, 191.9997, 192.1800,
         192.9746, 191.9382, 191.3509],
        [192.6496, 191.8741, 191.2653, 191.5553, 191.9598, 192.0033, 192.1850,
         192.9815, 191.9447, 191.3559],
        [192.4908, 191.7174, 191.1077, 191.3978, 191.8021, 191.8450, 192.0273,
         192.8232, 191.7865, 191.1972],
        [192.6408, 191

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.3553, 192.0040, 192.1698, 191.8167, 192.1008, 191.9607, 192.3337,
         192.0328, 192.3205, 192.0184],
        [192.3360, 191.9863, 192.1495, 191.7982, 192.0833, 191.9418, 192.3151,
         192.0144, 192.3015, 191.9982],
        [192.4756, 192.1262, 192.2895, 191.9375, 192.2221, 192.0817, 192.4554,
         192.1541, 192.4412, 192.1375],
        [192.2466, 191.8936, 192.0578, 191.7063, 191.9907, 191.8499, 192.2243,
         191.9231, 192.2102, 191.9060],
        [192.1569, 191.8041, 191.9713, 191.6182, 191.9024, 191.7615, 192.1337,
         191.8337, 192.1217, 191.8209],
        [192.2276, 191.8751, 192.0403, 191.6873, 191.9708, 191.8310, 192.2058,
         191.9040, 192.1918, 191.8880],
        [192.4910, 192.1413, 192.3053, 191.9531, 192.2378, 192.0973, 192.4704,
         192.1695, 192.4566, 192.1537],
        [192.2768, 191

       grad_fn=<CdistBackward0>)
tensor([3, 3, 1, 3, 3, 3, 1, 3, 3, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.3554, 192.0033, 192.2105, 192.4453, 192.3342, 192.4542, 192.0750,
         192.2725, 191.7041, 192.0611],
        [192.1537, 191.7994, 192.0084, 192.2421, 192.1319, 192.2511, 191.8722,
         192.0703, 191.4996, 191.8563],
        [192.0651, 191.7117, 191.9198, 192.1554, 192.0459, 192.1642, 191.7841,
         191.9816, 191.4130, 191.7700],
        [191.8284, 191.4725, 191.6824, 191.9172, 191.8092, 191.9261, 191.5462,
         191.7440, 191.1732, 191.5298],
        [192.2029, 191.8495, 192.0576, 192.2919, 192.1809, 192.3008, 191.9218,
         192.1200, 191.5502, 191.9073],
        [192.2033, 191.8499, 192.0591, 192.2940, 192.1858, 192.3032, 191.9225,
         192.1192, 191.5503, 191.9065],
        [192.2031, 191.8500, 192.0578, 192.2931, 192.1831, 192.3017, 191.9223,
         192.1197, 191.5508, 191.9077],
        [192.1476, 191

       grad_fn=<CdistBackward0>)
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.8932, 192.1932, 191.8461, 191.9806, 192.1248, 192.0432, 191.7023,
         192.6275, 192.2063, 192.7556],
        [192.1085, 192.4089, 192.0617, 192.1957, 192.3399, 192.2572, 191.9183,
         192.8425, 192.4218, 192.9708],
        [192.1003, 192.4007, 192.0541, 192.1869, 192.3317, 192.2484, 191.9103,
         192.8340, 192.4136, 192.9626],
        [191.7487, 192.0480, 191.7026, 191.8351, 191.9800, 191.8989, 191.5583,
         192.4827, 192.0617, 192.6113],
        [191.9183, 192.2198, 191.8715, 192.0049, 192.1505, 192.0667, 191.7268,
         192.6531, 192.2314, 192.7804],
        [191.7206, 192.0217, 191.6758, 191.8062, 191.9525, 191.8694, 191.5309,
         192.4543, 192.0342, 192.5831],
        [191.8589, 192.1592, 191.8141, 191.9448, 192.0902, 192.0078, 191.6702,
         192.5921, 192.1727, 192.7216],
        [191.7824, 192

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.8416, 191.1585, 191.8275, 191.8508, 191.7612, 191.8884, 191.9816,
         191.9758, 191.2693, 191.9798],
        [191.9446, 191.2604, 191.9313, 191.9514, 191.8619, 191.9906, 192.0827,
         192.0773, 191.3736, 192.0819],
        [192.2137, 191.5317, 192.2007, 192.2222, 192.1330, 192.2614, 192.3531,
         192.3478, 191.6444, 192.3520],
        [192.0902, 191.4070, 192.0770, 192.0979, 192.0084, 192.1371, 192.2291,
         192.2234, 191.5202, 192.2281],
        [191.8963, 191.2101, 191.8826, 191.9016, 191.8135, 191.9418, 192.0329,
         192.0300, 191.3240, 192.0326],
        [191.9471, 191.2611, 191.9331, 191.9536, 191.8645, 191.9936, 192.0858,
         192.0801, 191.3757, 192.0848],
        [191.9147, 191.2282, 191.9007, 191.9207, 191.8317, 191.9609, 192.0530,
         192.0475, 191.3431, 192.0522],
        [191.9161, 191

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.7091, 191.7546, 191.4926, 192.3326, 191.7332, 191.4789, 191.6501,
         191.5963, 192.6099, 191.5983],
        [192.0193, 192.0655, 191.8022, 192.6428, 192.0454, 191.7882, 191.9608,
         191.9075, 192.9191, 191.9087],
        [191.8943, 191.9393, 191.6753, 192.5148, 191.9178, 191.6617, 191.8333,
         191.7801, 192.7936, 191.7817],
        [191.8345, 191.8796, 191.6167, 192.4560, 191.8577, 191.6030, 191.7740,
         191.7205, 192.7345, 191.7224],
        [191.8064, 191.8518, 191.5891, 192.4288, 191.8304, 191.5754, 191.7466,
         191.6931, 192.7066, 191.6949],
        [191.9661, 192.0111, 191.7465, 192.5863, 191.9897, 191.7332, 191.9051,
         191.8517, 192.8650, 191.8533],
        [191.8512, 191.8966, 191.6328, 192.4727, 191.8756, 191.6191, 191.7907,
         191.7378, 192.7507, 191.7392],
        [191.8544, 191

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.7077, 191.8084, 191.8281, 192.0904, 191.8126, 191.9408, 191.6725,
         192.2646, 192.0250, 192.2083],
        [191.7660, 191.8667, 191.8853, 192.1445, 191.8718, 191.9947, 191.7298,
         192.3225, 192.0828, 192.2634],
        [191.6883, 191.7889, 191.8080, 192.0683, 191.7932, 191.9183, 191.6521,
         192.2452, 192.0054, 192.1867],
        [191.6871, 191.7878, 191.8073, 192.0660, 191.7927, 191.9160, 191.6518,
         192.2449, 192.0056, 192.1857],
        [191.6957, 191.7971, 191.8169, 192.0776, 191.7985, 191.9268, 191.6607,
         192.2553, 192.0145, 192.1959],
        [191.4535, 191.5531, 191.5731, 191.8343, 191.5592, 191.6839, 191.4163,
         192.0095, 191.7710, 191.9522],
        [191.6503, 191.7516, 191.7715, 192.0324, 191.7530, 191.8815, 191.6152,
         192.2099, 191.9693, 192.1507],
        [191.6698, 191

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.9583, 191.6586, 191.7851, 192.1888, 191.9558, 191.3168, 191.6739,
         191.8725, 192.0556, 191.4591],
        [192.0635, 191.7639, 191.8931, 192.2924, 192.0632, 191.4229, 191.7811,
         191.9800, 192.1611, 191.5652],
        [191.9404, 191.6406, 191.7690, 192.1697, 191.9394, 191.2992, 191.6573,
         191.8563, 192.0379, 191.4417],
        [191.7370, 191.4388, 191.5648, 191.9645, 191.7341, 191.0930, 191.4514,
         191.6516, 191.8322, 191.2373],
        [192.2091, 191.9119, 192.0393, 192.4361, 192.2054, 191.5672, 191.9244,
         192.1236, 192.3043, 191.7100],
        [191.9321, 191.6347, 191.7614, 192.1589, 191.9290, 191.2888, 191.6470,
         191.8469, 192.0273, 191.4327],
        [191.9651, 191.6662, 191.7928, 192.1947, 191.9626, 191.3234, 191.6807,
         191.8796, 192.0622, 191.4661],
        [192.0000, 191

tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.6860, 191.7124, 191.9228, 191.6011, 191.7972, 192.0370, 191.6495,
         192.3552, 191.5995, 191.5802],
        [192.6947, 191.7193, 191.9311, 191.6090, 191.8042, 192.0466, 191.6567,
         192.3640, 191.6078, 191.5875],
        [192.8873, 191.9153, 192.1239, 191.8024, 191.9987, 192.2394, 191.8504,
         192.5548, 191.8008, 191.7816],
        [192.8087, 191.8359, 192.0455, 191.7238, 191.9199, 192.1607, 191.7720,
         192.4764, 191.7220, 191.7031],
        [192.7417, 191.7677, 191.9787, 191.6561, 191.8517, 192.0942, 191.7039,
         192.4093, 191.6547, 191.6348],
        [192.7917, 191.8174, 192.0284, 191.7061, 191.9015, 192.1441, 191.7538,
         192.4595, 191.7047, 191.6849],
        [192.6058, 191.6317, 191.8434, 191.5206, 191.7162, 191.9577, 191.5686,
         192.2743, 191.5192, 191.4993],
        [192.6786, 191.7030, 191.9148, 191.5932, 191.78

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.1261, 191.1399, 191.4484, 191.7921, 191.3914, 191.7997, 192.2332,
         191.7093, 190.9629, 191.7872],
        [191.1912, 191.2025, 191.5102, 191.8542, 191.4566, 191.8616, 192.2955,
         191.7706, 191.0269, 191.8496],
        [191.3595, 191.3736, 191.6816, 192.0231, 191.6239, 192.0307, 192.4633,
         191.9392, 191.1990, 192.0191],
        [191.0414, 191.0546, 191.3640, 191.7063, 191.3060, 191.7143, 192.1473,
         191.6233, 190.8784, 191.7021],
        [191.2643, 191.2771, 191.5853, 191.9268, 191.5286, 191.9344, 192.3672,
         191.8427, 191.1028, 191.9230],
        [191.5301, 191.5441, 191.8503, 192.1938, 191.7958, 192.2012, 192.6344,
         192.1099, 191.3689, 192.1890],
        [191.2056, 191.2175, 191.5255, 191.8681, 191.4701, 191.8748, 192.3084,
         191.7839, 191.0426, 191.8637],
        [191.2792, 191

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.7965, 191.4221, 191.9534, 191.8587, 191.9780, 191.8157, 192.1628,
         191.8907, 191.9910, 192.9274],
        [191.5563, 191.1812, 191.7142, 191.6176, 191.7386, 191.5750, 191.9229,
         191.6506, 191.7520, 192.6865],
        [191.3539, 190.9798, 191.5116, 191.4158, 191.5383, 191.3726, 191.7236,
         191.4499, 191.5525, 192.4851],
        [191.4048, 191.0318, 191.5622, 191.4676, 191.5899, 191.4242, 191.7752,
         191.5013, 191.6039, 192.5369],
        [191.7888, 191.4152, 191.9458, 191.8515, 191.9711, 191.8086, 192.1557,
         191.8834, 191.9841, 192.9203],
        [191.6507, 191.2746, 191.8083, 191.7117, 191.8314, 191.6690, 192.0164,
         191.7444, 191.8448, 192.7808],
        [191.6663, 191.2947, 191.8237, 191.7301, 191.8516, 191.6866, 192.0355,
         191.7626, 191.8647, 192.7986],
        [191.3856, 191

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.5409, 192.3828, 191.4433, 191.3668, 191.4304, 191.6655, 191.6715,
         191.8163, 191.5261, 191.5945],
        [192.7785, 192.6212, 191.6827, 191.6035, 191.6670, 191.9037, 191.9086,
         192.0511, 191.7636, 191.8332],
        [192.6733, 192.5138, 191.5709, 191.4966, 191.5621, 191.7985, 191.8013,
         191.9461, 191.6560, 191.7236],
        [192.3221, 192.1650, 191.2246, 191.1468, 191.2102, 191.4483, 191.4522,
         191.5969, 191.3063, 191.3758],
        [192.7242, 192.5653, 191.6246, 191.5500, 191.6133, 191.8485, 191.8534,
         191.9991, 191.7094, 191.7756],
        [192.6257, 192.4678, 191.5295, 191.4520, 191.5155, 191.7500, 191.7570,
         191.9010, 191.6114, 191.6803],
        [192.6615, 192.5036, 191.5651, 191.4875, 191.5511, 191.7860, 191.7924,
         191.9362, 191.6471, 191.7160],
        [192.6179, 192.4602, 191.5194, 191.4435, 191.50

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.1438, 191.5788, 191.5113, 191.5277, 191.8092, 191.1077, 191.7798,
         192.6130, 191.4910, 191.8173],
        [191.9671, 191.4000, 191.3315, 191.3502, 191.6320, 190.9296, 191.6032,
         192.4348, 191.3151, 191.6398],
        [192.2463, 191.6809, 191.6137, 191.6301, 191.9112, 191.2121, 191.8817,
         192.7148, 191.5931, 191.9192],
        [192.0784, 191.5121, 191.4437, 191.4620, 191.7435, 191.0417, 191.7159,
         192.5468, 191.4270, 191.7518],
        [192.3217, 191.7554, 191.6884, 191.7055, 191.9861, 191.2879, 191.9580,
         192.7898, 191.6692, 191.9943],
        [192.1353, 191.5697, 191.5015, 191.5194, 191.8005, 191.0999, 191.7730,
         192.6040, 191.4839, 191.8091],
        [192.0484, 191.4807, 191.4124, 191.4314, 191.7131, 191.0105, 191.6850,
         192.5163, 191.3966, 191.7208],
        [192.0466, 191

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.4966, 191.5095, 191.5446, 191.4557, 191.1898, 191.4735, 191.7431,
         191.7759, 192.2999, 190.9243],
        [191.3220, 191.3370, 191.3727, 191.2815, 191.0158, 191.3014, 191.5728,
         191.6052, 192.1244, 190.7504],
        [191.3608, 191.3742, 191.4089, 191.3195, 191.0533, 191.3375, 191.6079,
         191.6406, 192.1638, 190.7884],
        [191.1093, 191.1231, 191.1584, 191.0682, 190.8014, 191.0870, 191.3588,
         191.3911, 191.9108, 190.5374],
        [191.3253, 191.3397, 191.3752, 191.2847, 191.0188, 191.3043, 191.5752,
         191.6074, 192.1274, 190.7535],
        [191.3992, 191.4145, 191.4499, 191.3583, 191.0931, 191.3783, 191.6496,
         191.6822, 192.2022, 190.8274],
        [191.3937, 191.4081, 191.4438, 191.3532, 191.0877, 191.3721, 191.6432,
         191.6763, 192.1969, 190.8222],
        [191.4754, 191

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.2482, 191.4159, 192.1663, 191.7552, 191.9473, 191.4200, 191.7217,
         191.3627, 191.3283, 190.8144],
        [191.3573, 191.5286, 192.2791, 191.8678, 192.0559, 191.5339, 191.8336,
         191.4748, 191.4430, 190.9239],
        [191.1708, 191.3395, 192.0905, 191.6797, 191.8697, 191.3456, 191.6459,
         191.2865, 191.2536, 190.7373],
        [191.3587, 191.5293, 192.2802, 191.8686, 192.0564, 191.5340, 191.8348,
         191.4758, 191.4433, 190.9249],
        [190.9597, 191.1297, 191.8825, 191.4723, 191.6591, 191.1357, 191.4370,
         191.0766, 191.0440, 190.5274],
        [191.2883, 191.4588, 192.2110, 191.7994, 191.9866, 191.4604, 191.7646,
         191.4050, 191.3707, 190.8554],
        [191.1745, 191.3464, 192.0990, 191.6879, 191.8730, 191.3499, 191.6525,
         191.2925, 191.2597, 190.7421],
        [191.2287, 191

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[190.9895, 191.7686, 192.3338, 191.4966, 191.3606, 191.6625, 191.7242,
         191.5676, 191.4090, 191.5468],
        [190.9764, 191.7560, 192.3225, 191.4839, 191.3472, 191.6487, 191.7113,
         191.5546, 191.3964, 191.5339],
        [191.1403, 191.9145, 192.4834, 191.6472, 191.5067, 191.8095, 191.8726,
         191.7147, 191.5576, 191.6972],
        [190.8782, 191.6582, 192.2244, 191.3858, 191.2492, 191.5502, 191.6132,
         191.4565, 191.2979, 191.4363],
        [191.0672, 191.8471, 192.4117, 191.5741, 191.4392, 191.7411, 191.8024,
         191.6461, 191.4873, 191.6241],
        [190.9001, 191.6793, 192.2453, 191.4080, 191.2718, 191.5745, 191.6355,
         191.4782, 191.3209, 191.4571],
        [191.1470, 191.9212, 192.4890, 191.6538, 191.5144, 191.8179, 191.8796,
         191.7217, 191.5647, 191.7033],
        [191.0991, 191

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.6007, 191.6048, 192.2414, 191.5158, 191.4533, 190.7569, 192.1765,
         191.1309, 191.2972, 191.5483],
        [191.5942, 191.5993, 192.2363, 191.5110, 191.4485, 190.7510, 192.1701,
         191.1215, 191.2911, 191.5432],
        [191.8148, 191.8185, 192.4553, 191.7296, 191.6694, 190.9709, 192.3895,
         191.3423, 191.5109, 191.7612],
        [191.6425, 191.6467, 192.2841, 191.5582, 191.4951, 190.7994, 192.2181,
         191.1700, 191.3392, 191.5909],
        [191.8974, 191.9020, 192.5385, 191.8132, 191.7513, 191.0546, 192.4736,
         191.4274, 191.5945, 191.8456],
        [191.6612, 191.6660, 192.3028, 191.5776, 191.5163, 190.8178, 192.2366,
         191.1880, 191.3578, 191.6095],
        [191.7370, 191.7408, 192.3774, 191.6518, 191.5899, 190.8934, 192.3125,
         191.2666, 191.4334, 191.6842],
        [191.6650, 191

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.4442, 191.6053, 190.9305, 191.7498, 191.7461, 191.4608, 191.5999,
         191.5750, 191.3905, 191.1289],
        [191.2204, 191.3850, 190.7086, 191.5286, 191.5224, 191.2390, 191.3754,
         191.3495, 191.1670, 190.9068],
        [191.2939, 191.4575, 190.7831, 191.6021, 191.5959, 191.3128, 191.4512,
         191.4251, 191.2415, 190.9810],
        [191.2653, 191.4263, 190.7510, 191.5726, 191.5664, 191.2835, 191.4208,
         191.3959, 191.2116, 190.9510],
        [191.2674, 191.4326, 190.7566, 191.5757, 191.5696, 191.2860, 191.4224,
         191.3963, 191.2143, 190.9538],
        [191.2742, 191.4388, 190.7632, 191.5824, 191.5763, 191.2928, 191.4300,
         191.4039, 191.2212, 190.9609],
        [191.5659, 191.7276, 191.0541, 191.8723, 191.8662, 191.5836, 191.7203,
         191.6958, 191.5123, 191.2504],
        [191.2505, 191

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.7908, 191.8886, 191.1557, 191.8666, 191.4908, 191.1422, 190.9512,
         191.8042, 191.8101, 191.6541],
        [191.6393, 191.7382, 191.0042, 191.7180, 191.3386, 190.9890, 190.7989,
         191.6521, 191.6582, 191.5049],
        [191.5149, 191.6120, 190.8777, 191.5908, 191.2138, 190.8651, 190.6738,
         191.5275, 191.5337, 191.3772],
        [191.5933, 191.6911, 190.9577, 191.6711, 191.2931, 190.9434, 190.7536,
         191.6072, 191.6117, 191.4580],
        [191.4719, 191.5698, 190.8354, 191.5500, 191.1706, 190.8204, 190.6308,
         191.4847, 191.4906, 191.3361],
        [191.4530, 191.5497, 190.8152, 191.5295, 191.1510, 190.8002, 190.6113,
         191.4656, 191.4720, 191.3150],
        [191.6021, 191.6982, 190.9646, 191.6764, 191.3009, 190.9518, 190.7612,
         191.6151, 191.6211, 191.4627],
        [191.3678, 191

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.6714, 191.8219, 191.4444, 191.6452, 192.5405, 191.1158, 191.7654,
         191.5522, 191.8150, 191.0076],
        [191.3686, 191.5192, 191.1397, 191.3397, 192.2365, 190.8116, 191.4631,
         191.2480, 191.5127, 190.7050],
        [191.1324, 191.2824, 190.9035, 191.1042, 192.0008, 190.5739, 191.2257,
         191.0101, 191.2756, 190.4673],
        [191.2391, 191.3889, 191.0109, 191.2117, 192.1080, 190.6808, 191.3318,
         191.1171, 191.3817, 190.5736],
        [191.4247, 191.5759, 191.1977, 191.3978, 192.2940, 190.8679, 191.5199,
         191.3053, 191.5695, 190.7620],
        [191.4392, 191.5893, 191.2124, 191.4140, 192.3090, 190.8820, 191.5322,
         191.3186, 191.5820, 190.7741],
        [191.4961, 191.6474, 191.2692, 191.4700, 192.3651, 190.9399, 191.5916,
         191.3771, 191.6412, 190.8336],
        [191.2906, 191

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.2684, 191.3501, 191.4947, 191.4595, 191.4776, 191.9965, 191.1450,
         191.5238, 192.3448, 190.7305],
        [192.1735, 191.2527, 191.3966, 191.3628, 191.3815, 191.8981, 191.0485,
         191.4283, 192.2468, 190.6352],
        [192.3396, 191.4216, 191.5663, 191.5308, 191.5495, 192.0676, 191.2156,
         191.5955, 192.4164, 190.8036],
        [192.2471, 191.3296, 191.4739, 191.4384, 191.4570, 191.9755, 191.1234,
         191.5023, 192.3237, 190.7099],
        [192.3152, 191.3962, 191.5387, 191.5051, 191.5246, 192.0398, 191.1898,
         191.5690, 192.3886, 190.7793],
        [192.2069, 191.2868, 191.4300, 191.3963, 191.4156, 191.9312, 191.0814,
         191.4614, 192.2801, 190.6698],
        [192.3371, 191.4167, 191.5605, 191.5265, 191.5459, 192.0613, 191.2112,
         191.5924, 192.4110, 190.8015],
        [192.3146, 191

tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.8691, 191.2568, 191.1834, 191.5537, 191.5424, 191.3910, 191.3258,
         191.7663, 191.6132, 191.2105],
        [191.8943, 191.2814, 191.2083, 191.5789, 191.5676, 191.4161, 191.3498,
         191.7895, 191.6369, 191.2341],
        [191.8452, 191.2310, 191.1575, 191.5298, 191.5174, 191.3666, 191.2980,
         191.7371, 191.5855, 191.1826],
        [192.1093, 191.4962, 191.4234, 191.7939, 191.7820, 191.6317, 191.5647,
         192.0040, 191.8512, 191.4492],
        [191.7917, 191.1775, 191.1038, 191.4762, 191.4638, 191.3128, 191.2443,
         191.6834, 191.5320, 191.1290],
        [191.7562, 191.1433, 191.0702, 191.4409, 191.4298, 191.2783, 191.2119,
         191.6521, 191.5001, 191.0964],
        [191.7738, 191.1614, 191.0885, 191.4583, 191.4479, 191.2953, 191.2302,
         191.6701, 191.5176, 191.1139],
        [191.5752, 190.9617, 190.8880, 191.2598, 191.24

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.4205, 191.5015, 191.1761, 190.5942, 191.7229, 191.3007, 191.2831,
         191.5278, 190.6846, 191.0613],
        [191.3597, 191.4384, 191.1132, 190.5298, 191.6576, 191.2375, 191.2185,
         191.4638, 190.6178, 190.9971],
        [191.3330, 191.4116, 191.0860, 190.5029, 191.6306, 191.2115, 191.1918,
         191.4370, 190.5909, 190.9705],
        [191.4447, 191.5241, 191.1998, 190.6151, 191.7425, 191.3237, 191.3038,
         191.5498, 190.7047, 191.0836],
        [191.1404, 191.2196, 190.8936, 190.3117, 191.4400, 191.0184, 191.0000,
         191.2449, 190.3987, 190.7773],
        [191.3155, 191.3949, 191.0681, 190.4881, 191.6177, 191.1931, 191.1772,
         191.4205, 190.5760, 190.9536],
        [191.3391, 191.4187, 191.0922, 190.5114, 191.6404, 191.2183, 191.2007,
         191.4447, 190.6002, 190.9781],
        [191.5802, 191

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.2253, 191.0797, 191.1497, 191.6942, 192.0016, 190.9736, 191.0966,
         191.2953, 191.5904, 191.3429],
        [191.2790, 191.1335, 191.2036, 191.7478, 192.0552, 191.0282, 191.1506,
         191.3494, 191.6439, 191.3960],
        [191.2496, 191.1040, 191.1750, 191.7182, 192.0275, 191.0021, 191.1217,
         191.3207, 191.6145, 191.3655],
        [191.3452, 191.1994, 191.2706, 191.8135, 192.1227, 191.0959, 191.2172,
         191.4149, 191.7108, 191.4617],
        [191.4830, 191.3394, 191.4095, 191.9537, 192.2623, 191.2379, 191.3574,
         191.5554, 191.8479, 191.5984],
        [191.3191, 191.1730, 191.2437, 191.7869, 192.0948, 191.0687, 191.1904,
         191.3894, 191.6838, 191.4353],
        [191.1853, 191.0413, 191.1109, 191.6561, 191.9645, 190.9382, 191.0584,
         191.2573, 191.5502, 191.3020],
        [191.3679, 191

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.4270, 191.4691, 192.0012, 191.5501, 191.2493, 192.1906, 191.5299,
         191.0553, 191.7060, 191.5002],
        [191.2949, 191.3381, 191.8712, 191.4193, 191.1185, 192.0605, 191.3991,
         190.9240, 191.5760, 191.3697],
        [191.2419, 191.2834, 191.8152, 191.3642, 191.0644, 192.0057, 191.3458,
         190.8706, 191.5191, 191.3142],
        [191.5034, 191.5461, 192.0781, 191.6263, 191.3268, 192.2680, 191.6053,
         191.1328, 191.7828, 191.5768],
        [191.1846, 191.2261, 191.7585, 191.3076, 191.0069, 191.9485, 191.2896,
         190.8128, 191.4630, 191.2575],
        [191.4661, 191.5093, 192.0415, 191.5893, 191.2901, 192.2315, 191.5681,
         191.0962, 191.7456, 191.5398],
        [191.3499, 191.3937, 191.9260, 191.4740, 191.1735, 192.1152, 191.4522,
         190.9795, 191.6301, 191.4245],
        [191.3202, 191

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.3288, 191.9874, 191.6619, 191.7517, 191.5102, 191.3914, 191.4409,
         191.3535, 191.0038, 191.3591],
        [191.0979, 191.7556, 191.4308, 191.5203, 191.2793, 191.1597, 191.2121,
         191.1230, 190.7735, 191.1295],
        [191.0201, 191.6814, 191.3552, 191.4439, 191.2038, 191.0835, 191.1346,
         191.0462, 190.6964, 191.0522],
        [191.1236, 191.7807, 191.4561, 191.5464, 191.3039, 191.1846, 191.2357,
         191.1473, 190.7980, 191.1533],
        [191.3210, 191.9804, 191.6547, 191.7443, 191.5032, 191.3842, 191.4334,
         191.3460, 190.9965, 191.3518],
        [190.9377, 191.5982, 191.2727, 191.3617, 191.1227, 191.0009, 191.0536,
         190.9642, 190.6156, 190.9721],
        [191.2920, 191.9518, 191.6262, 191.7159, 191.4758, 191.3557, 191.4054,
         191.3177, 190.9689, 191.3246],
        [190.9554, 191

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.5693, 190.9361, 191.4892, 191.6448, 190.7613, 192.1454, 191.3076,
         192.2938, 191.5057, 191.5883],
        [191.2936, 190.6583, 191.2132, 191.3693, 190.4829, 191.8689, 191.0301,
         192.0174, 191.2293, 191.3121],
        [191.3624, 190.7274, 191.2813, 191.4381, 190.5502, 191.9375, 191.0989,
         192.0851, 191.2964, 191.3811],
        [191.3722, 190.7385, 191.2911, 191.4497, 190.5624, 191.9482, 191.1093,
         192.0972, 191.3074, 191.3903],
        [191.3468, 190.7122, 191.2660, 191.4235, 190.5379, 191.9222, 191.0835,
         192.0708, 191.2830, 191.3657],
        [191.4559, 190.8227, 191.3747, 191.5336, 190.6477, 192.0317, 191.1934,
         192.1806, 191.3916, 191.4746],
        [191.2354, 190.6000, 191.1548, 191.3114, 190.4240, 191.8107, 190.9716,
         191.9592, 191.1708, 191.2536],
        [191.3717, 190

tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.0143, 191.3838, 190.6612, 191.2107, 190.8290, 191.1522, 191.2676,
         191.0006, 191.3295, 190.9091],
        [190.7300, 191.0996, 190.3757, 190.9265, 190.5430, 190.8676, 190.9822,
         190.7163, 191.0449, 190.6247],
        [190.9805, 191.3496, 190.6266, 191.1777, 190.7937, 191.1172, 191.2325,
         190.9691, 191.2956, 190.8777],
        [191.0377, 191.4057, 190.6835, 191.2339, 190.8499, 191.1750, 191.2897,
         191.0229, 191.3520, 190.9328],
        [191.0712, 191.4413, 190.7189, 191.2678, 190.8870, 191.2090, 191.3246,
         191.0579, 191.3866, 190.9661],
        [190.9671, 191.3370, 190.6139, 191.1643, 190.7816, 191.1041, 191.2195,
         190.9557, 191.2825, 190.8638],
        [190.9835, 191.3517, 190.6292, 191.1791, 190.7964, 191.1216, 191.2366,
         190.9679, 191.2981, 190.8773],
        [191.1143, 191.4820, 190.7603, 191.3099, 190.92

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.2689, 190.7408, 191.1180, 190.9172, 191.7646, 190.9912, 191.1384,
         190.8969, 190.9569, 191.0484],
        [191.6653, 191.1395, 191.5127, 191.3137, 192.1611, 191.3871, 191.5342,
         191.2926, 191.3543, 191.4447],
        [191.5445, 191.0163, 191.3919, 191.1917, 192.0388, 191.2643, 191.4122,
         191.1704, 191.2309, 191.3215],
        [191.6627, 191.1360, 191.5112, 191.3122, 192.1582, 191.3852, 191.5330,
         191.2927, 191.3512, 191.4423],
        [191.5413, 191.0150, 191.3893, 191.1895, 192.0360, 191.2639, 191.4103,
         191.1699, 191.2294, 191.3204],
        [191.6123, 191.0859, 191.4608, 191.2615, 192.1069, 191.3354, 191.4824,
         191.2426, 191.3005, 191.3917],
        [191.5909, 191.0614, 191.4387, 191.2386, 192.0848, 191.3099, 191.4592,
         191.2177, 191.2764, 191.3671],
        [191.4512, 190

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.0531, 190.8878, 191.8537, 191.0646, 191.0814, 190.6818, 190.9170,
         191.3004, 191.4164, 191.9568],
        [190.9756, 190.8121, 191.7760, 190.9883, 191.0062, 190.6059, 190.8387,
         191.2260, 191.3401, 191.8813],
        [190.8786, 190.7129, 191.6789, 190.8900, 190.9065, 190.5059, 190.7417,
         191.1259, 191.2415, 191.7820],
        [190.8786, 190.7159, 191.6783, 190.8922, 190.9096, 190.5109, 190.7418,
         191.1302, 191.2439, 191.7848],
        [190.8406, 190.6756, 191.6408, 190.8522, 190.8693, 190.4684, 190.7036,
         191.0888, 191.2039, 191.7447],
        [190.9745, 190.8098, 191.7746, 190.9862, 191.0033, 190.6040, 190.8385,
         191.2225, 191.3383, 191.8786],
        [190.9750, 190.8094, 191.7752, 190.9859, 191.0028, 190.6028, 190.8387,
         191.2218, 191.3379, 191.8783],
        [190.9504, 190

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.1892, 191.2138, 191.0489, 191.2341, 191.5325, 191.4301, 191.3454,
         190.7078, 191.5222, 191.4498],
        [190.8160, 190.8399, 190.6745, 190.8599, 191.1588, 191.0560, 190.9699,
         190.3340, 191.1489, 191.0762],
        [190.9501, 190.9734, 190.8096, 190.9935, 191.2918, 191.1897, 191.1034,
         190.4670, 191.2824, 191.2103],
        [190.8922, 190.9162, 190.7516, 190.9365, 191.2354, 191.1325, 191.0468,
         190.4103, 191.2254, 191.1532],
        [191.1968, 191.2221, 191.0568, 191.2424, 191.5406, 191.4386, 191.3542,
         190.7161, 191.5302, 191.4583],
        [190.9136, 190.9393, 190.7724, 190.9599, 191.2593, 191.1557, 191.0712,
         190.4344, 191.2482, 191.1755],
        [190.8679, 190.8920, 190.7266, 190.9131, 191.2125, 191.1083, 191.0229,
         190.3871, 191.2016, 191.1288],
        [190.6964, 190

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[192.0377, 190.9053, 190.9321, 191.0002, 191.0471, 191.1077, 191.1364,
         191.4730, 190.5609, 190.7012],
        [191.8922, 190.7586, 190.7854, 190.8541, 190.9001, 190.9581, 190.9903,
         191.3286, 190.4153, 190.5524],
        [191.8878, 190.7559, 190.7854, 190.8503, 190.8972, 190.9559, 190.9870,
         191.3280, 190.4113, 190.5493],
        [191.9417, 190.8099, 190.8391, 190.9042, 190.9515, 191.0103, 191.0412,
         191.3811, 190.4653, 190.6041],
        [191.8722, 190.7389, 190.7655, 190.8339, 190.8806, 190.9393, 190.9705,
         191.3079, 190.3948, 190.5341],
        [191.9303, 190.7973, 190.8251, 190.8920, 190.9392, 190.9959, 191.0295,
         191.3680, 190.4534, 190.5921],
        [191.7402, 190.6070, 190.6337, 190.7021, 190.7483, 190.8087, 190.8376,
         191.1767, 190.2624, 190.4014],
        [191.8775, 190.7445, 190.7722, 190.8398, 190.88

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[190.9741, 190.8220, 191.1888, 191.1723, 191.0416, 190.8833, 191.0845,
         191.1295, 191.2599, 192.0489],
        [190.9372, 190.7846, 191.1514, 191.1357, 191.0022, 190.8455, 191.0461,
         191.0926, 191.2228, 192.0128],
        [190.9721, 190.8206, 191.1879, 191.1704, 191.0396, 190.8819, 191.0826,
         191.1283, 191.2583, 192.0465],
        [190.9184, 190.7651, 191.1319, 191.1165, 190.9850, 190.8267, 191.0276,
         191.0730, 191.2037, 191.9937],
        [191.1446, 190.9936, 191.3597, 191.3427, 191.2107, 191.0537, 191.2554,
         191.3004, 191.4304, 192.2195],
        [190.6332, 190.4795, 190.8484, 190.8324, 190.7015, 190.5420, 190.7429,
         190.7883, 190.9197, 191.7078],
        [190.8592, 190.7063, 191.0738, 191.0576, 190.9250, 190.7680, 190.9677,
         191.0151, 191.1449, 191.9347],
        [191.0489, 190

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[190.9010, 191.1310, 190.8287, 190.7597, 191.0659, 191.2268, 190.3223,
         190.6981, 190.4297, 190.7293],
        [190.8980, 191.1273, 190.8251, 190.7567, 191.0614, 191.2234, 190.3188,
         190.6937, 190.4243, 190.7254],
        [191.0729, 191.3037, 191.0018, 190.9305, 191.2361, 191.3962, 190.4940,
         190.8698, 190.6021, 190.9019],
        [190.6546, 190.8848, 190.5816, 190.5119, 190.8185, 190.9790, 190.0762,
         190.4496, 190.1820, 190.4831],
        [190.9930, 191.2238, 190.9213, 190.8519, 191.1580, 191.3184, 190.4144,
         190.7907, 190.5228, 190.8218],
        [190.9959, 191.2258, 190.9237, 190.8549, 191.1593, 191.3212, 190.4166,
         190.7923, 190.5230, 190.8238],
        [190.7097, 190.9383, 190.6363, 190.5667, 190.8720, 191.0337, 190.1307,
         190.5037, 190.2347, 190.5370],
        [190.9594, 191

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.1018, 190.9632, 191.8893, 190.7761, 190.2882, 191.5446, 191.3194,
         190.2558, 191.1277, 191.0438],
        [191.0309, 190.8905, 191.8178, 190.7055, 190.2159, 191.4714, 191.2513,
         190.1840, 191.0567, 190.9717],
        [191.0802, 190.9417, 191.8667, 190.7543, 190.2652, 191.5210, 191.2999,
         190.2336, 191.1065, 191.0209],
        [191.0751, 190.9375, 191.8623, 190.7489, 190.2615, 191.5181, 191.2924,
         190.2292, 191.1010, 191.0170],
        [190.9165, 190.7768, 191.7037, 190.5904, 190.1022, 191.3584, 191.1357,
         190.0697, 190.9419, 190.8579],
        [191.0921, 190.9541, 191.8789, 190.7655, 190.2773, 191.5332, 191.3125,
         190.2462, 191.1185, 191.0329],
        [190.8834, 190.7459, 191.6700, 190.5568, 190.0691, 191.3256, 191.1021,
         190.0370, 190.9092, 190.8248],
        [191.1644, 191

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[190.9145, 190.8312, 190.4051, 191.2472, 190.9219, 190.5807, 191.4394,
         191.2268, 191.1327, 191.1294],
        [190.7843, 190.7002, 190.2783, 191.1191, 190.7917, 190.4523, 191.3125,
         191.0968, 191.0035, 190.9999],
        [190.7639, 190.6803, 190.2574, 191.0991, 190.7715, 190.4315, 191.2906,
         191.0766, 190.9840, 190.9799],
        [190.9569, 190.8738, 190.4500, 191.2915, 190.9643, 190.6240, 191.4815,
         191.2689, 191.1766, 191.1727],
        [190.8443, 190.7585, 190.3355, 191.1784, 190.8480, 190.5113, 191.3698,
         191.1570, 191.0638, 191.0591],
        [190.7112, 190.6282, 190.2024, 191.0455, 190.7193, 190.3777, 191.2368,
         191.0242, 190.9313, 190.9271],
        [190.5161, 190.4327, 190.0068, 190.8507, 190.5242, 190.1827, 191.0432,
         190.8299, 190.7367, 190.7322],
        [190.5632, 190.4781, 190.0545, 190.8980, 190.56

tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[190.7070, 191.1470, 190.4270, 191.6968, 191.0312, 191.0602, 191.5103,
         190.6293, 191.2846, 190.6606],
        [190.3819, 190.8240, 190.1023, 191.3731, 190.7068, 190.7367, 191.1860,
         190.3031, 190.9630, 190.3367],
        [190.6332, 191.0736, 190.3529, 191.6228, 190.9568, 190.9858, 191.4358,
         190.5552, 191.2114, 190.5868],
        [190.8736, 191.3130, 190.5933, 191.8624, 191.1957, 191.2245, 191.6750,
         190.7959, 191.4510, 190.8269],
        [190.6348, 191.0756, 190.3554, 191.6243, 190.9608, 190.9887, 191.4379,
         190.5569, 191.2140, 190.5880],
        [190.9104, 191.3497, 190.6304, 191.8987, 191.2334, 191.2614, 191.7115,
         190.8330, 191.4878, 190.8632],
        [190.6477, 191.0889, 190.3681, 191.6369, 190.9715, 190.9997, 191.4490,
         190.5696, 191.2281, 190.6013],
        [190.7201, 191.1605, 190.4409, 191.7094, 191.04

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.0334, 190.9631, 191.3114, 190.4579, 191.1560, 190.9872, 190.8527,
         191.0720, 191.0760, 191.1910],
        [190.8154, 190.7440, 191.0923, 190.2382, 190.9376, 190.7683, 190.6325,
         190.8541, 190.8575, 190.9724],
        [190.9399, 190.8699, 191.2197, 190.3639, 191.0637, 190.8941, 190.7590,
         190.9790, 190.9835, 191.0977],
        [190.9258, 190.8541, 191.2019, 190.3488, 191.0473, 190.8780, 190.7439,
         190.9637, 190.9672, 191.0828],
        [191.1098, 191.0410, 191.3882, 190.5362, 191.2323, 191.0657, 190.9287,
         191.1495, 191.1530, 191.2679],
        [190.8438, 190.7723, 191.1221, 190.2662, 190.9670, 190.7960, 190.6630,
         190.8817, 190.8864, 191.0011],
        [191.0198, 190.9472, 191.2964, 190.4422, 191.1410, 190.9714, 190.8378,
         191.0573, 191.0607, 191.1765],
        [190.8897, 190.8203, 191.1696, 190.3141, 191.01

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[190.5869, 190.5941, 190.5588, 191.6087, 190.6097, 190.5880, 190.6767,
         190.2787, 191.6818, 190.5702],
        [190.7619, 190.7679, 190.7337, 191.7825, 190.7853, 190.7637, 190.8501,
         190.4536, 191.8579, 190.7470],
        [190.7987, 190.8063, 190.7711, 191.8209, 190.8215, 190.8002, 190.8878,
         190.4903, 191.8936, 190.7823],
        [190.7077, 190.7145, 190.6775, 191.7281, 190.7305, 190.7097, 190.7970,
         190.3987, 191.8028, 190.6903],
        [190.7545, 190.7633, 190.7246, 191.7759, 190.7770, 190.7561, 190.8442,
         190.4459, 191.8488, 190.7368],
        [190.8321, 190.8419, 190.8009, 191.8532, 190.8544, 190.8340, 190.9221,
         190.5232, 191.9258, 190.8132],
        [190.9245, 190.9330, 190.8952, 191.9457, 190.9473, 190.9265, 191.0131,
         190.6160, 192.0191, 190.9081],
        [190.7967, 190

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[190.8618, 191.1677, 190.6426, 190.7984, 190.6799, 190.8103, 190.3323,
         190.9820, 191.0166, 190.5886],
        [190.5991, 190.9048, 190.3780, 190.5373, 190.4159, 190.5493, 190.0706,
         190.7195, 190.7546, 190.3254],
        [190.9950, 191.3009, 190.7763, 190.9341, 190.8136, 190.9451, 190.4661,
         191.1143, 191.1490, 190.7230],
        [191.1160, 191.4225, 190.8973, 191.0538, 190.9354, 191.0643, 190.5869,
         191.2345, 191.2710, 190.8441],
        [190.7354, 191.0406, 190.5150, 190.6739, 190.5527, 190.6855, 190.2066,
         190.8553, 190.8900, 190.4622],
        [190.7452, 191.0538, 190.5251, 190.6822, 190.5635, 190.6939, 190.2171,
         190.8647, 190.9032, 190.4722],
        [190.9606, 191.2662, 190.7409, 190.8974, 190.7792, 190.9084, 190.4313,
         191.0796, 191.1155, 190.6877],
        [190.9958, 191.3010, 190.7765, 190.9346, 190.81

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.0242, 191.1676, 191.0443, 190.8730, 191.9596, 190.3620, 190.7384,
         191.0864, 191.7378, 191.4981],
        [190.7272, 190.8718, 190.7490, 190.5764, 191.6633, 190.0654, 190.4404,
         190.7898, 191.4416, 191.2015],
        [190.9006, 191.0433, 190.9201, 190.7493, 191.8354, 190.2374, 190.6138,
         190.9627, 191.6137, 191.3745],
        [190.8329, 190.9761, 190.8533, 190.6808, 191.7695, 190.1694, 190.5467,
         190.8964, 191.5471, 191.3087],
        [190.7043, 190.8474, 190.7235, 190.5538, 191.6377, 190.0406, 190.4162,
         190.7654, 191.4168, 191.1763],
        [190.7554, 190.8976, 190.7740, 190.6043, 191.6887, 190.0913, 190.4674,
         190.8167, 191.4676, 191.2279],
        [190.7876, 190.9306, 190.8080, 190.6353, 191.7242, 190.1239, 190.5012,
         190.8512, 191.5017, 191.2634],
        [190.7689, 190

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.0907, 190.8631, 191.0287, 190.8132, 191.8833, 191.0630, 191.6622,
         190.7603, 190.6637, 191.1650],
        [190.7374, 190.5085, 190.6741, 190.4584, 191.5291, 190.7091, 191.3075,
         190.4073, 190.3069, 190.8125],
        [190.9373, 190.7097, 190.8759, 190.6594, 191.7295, 190.9100, 191.5089,
         190.6082, 190.5095, 191.0134],
        [190.9395, 190.7118, 190.8769, 190.6616, 191.7319, 190.9111, 191.5108,
         190.6094, 190.5109, 191.0148],
        [190.8172, 190.5892, 190.7553, 190.5388, 191.6092, 190.7898, 191.3884,
         190.4885, 190.3883, 190.8938],
        [190.8181, 190.5888, 190.7543, 190.5391, 191.6102, 190.7905, 191.3881,
         190.4883, 190.3877, 190.8931],
        [190.9776, 190.7496, 190.9150, 190.6997, 191.7703, 190.9503, 191.5489,
         190.6488, 190.5491, 191.0541],
        [190.9840, 190

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[190.8732, 191.0792, 190.2934, 190.1176, 190.1198, 191.1488, 190.7524,
         191.5576, 190.3149, 190.7107],
        [190.8310, 191.0368, 190.2507, 190.0743, 190.0774, 191.1065, 190.7096,
         191.5160, 190.2727, 190.6693],
        [190.7239, 190.9287, 190.1436, 189.9683, 189.9704, 190.9996, 190.6031,
         191.4095, 190.1671, 190.5639],
        [190.7767, 190.9814, 190.1972, 190.0231, 190.0231, 191.0533, 190.6576,
         191.4617, 190.2201, 190.6159],
        [190.5628, 190.7677, 189.9819, 189.8055, 189.8090, 190.8397, 190.4431,
         191.2478, 190.0050, 190.4018],
        [190.5297, 190.7340, 189.9487, 189.7729, 189.7760, 190.8062, 190.4095,
         191.2155, 189.9730, 190.3702],
        [190.6144, 190.8189, 190.0343, 189.8589, 189.8608, 190.8918, 190.4951,
         191.3005, 190.0581, 190.4555],
        [190.9627, 191

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[190.7723, 190.4034, 190.9244, 190.6028, 190.0710, 190.6279, 191.3409,
         191.4043, 191.5628, 189.8790],
        [190.5301, 190.1589, 190.6815, 190.3568, 189.8289, 190.3872, 191.0979,
         191.1627, 191.3192, 189.6348],
        [190.8515, 190.4810, 191.0058, 190.6786, 190.1514, 190.7077, 191.4202,
         191.4846, 191.6426, 189.9569],
        [190.6373, 190.2662, 190.7879, 190.4641, 189.9358, 190.4938, 191.2047,
         191.2692, 191.4257, 189.7417],
        [190.8368, 190.4671, 190.9881, 190.6658, 190.1364, 190.6933, 191.4049,
         191.4684, 191.6261, 189.9424],
        [191.0679, 190.6984, 191.2219, 190.8963, 190.3694, 190.9248, 191.6369,
         191.7003, 191.8589, 190.1737],
        [190.9083, 190.5388, 191.0606, 190.7373, 190.2077, 190.7641, 191.4766,
         191.5401, 191.6982, 190.0142],
        [190.9858, 190

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[190.9777, 190.5429, 190.3574, 190.5611, 190.5232, 191.0470, 190.8618,
         190.5211, 191.4564, 190.1919],
        [190.9030, 190.4696, 190.2864, 190.4878, 190.4491, 190.9759, 190.7877,
         190.4480, 191.3827, 190.1194],
        [190.9597, 190.5266, 190.3445, 190.5451, 190.5066, 191.0344, 190.8441,
         190.5053, 191.4397, 190.1772],
        [190.8115, 190.3776, 190.1929, 190.3959, 190.3579, 190.8825, 190.6976,
         190.3567, 191.2922, 190.0267],
        [191.0719, 190.6376, 190.4534, 190.6567, 190.6192, 191.1427, 190.9565,
         190.6171, 191.5521, 190.2878],
        [190.6750, 190.2413, 190.0565, 190.2592, 190.2210, 190.7467, 190.5618,
         190.2203, 191.1560, 189.8900],
        [191.0726, 190.6386, 190.4542, 190.6573, 190.6192, 191.1427, 190.9568,
         190.6174, 191.5522, 190.2882],
        [190.8273, 190

tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[190.5261, 190.5536, 190.4841, 190.7097, 190.6339, 190.0665, 190.4187,
         189.8043, 190.6459, 190.5745],
        [190.8411, 190.8679, 190.7974, 191.0231, 190.9474, 190.3811, 190.7323,
         190.1175, 190.9587, 190.8881],
        [190.8133, 190.8405, 190.7694, 190.9944, 190.9194, 190.3525, 190.7037,
         190.0888, 190.9301, 190.8595],
        [190.3690, 190.3973, 190.3285, 190.5531, 190.4773, 189.9091, 190.2618,
         189.6479, 190.4890, 190.4176],
        [190.6409, 190.6694, 190.5986, 190.8222, 190.7480, 190.1800, 190.5315,
         189.9173, 190.7578, 190.6877],
        [190.5186, 190.5467, 190.4762, 190.7003, 190.6249, 190.0563, 190.4090,
         189.7939, 190.6359, 190.5639],
        [190.6429, 190.6704, 190.6008, 190.8252, 190.7491, 190.1818, 190.5337,
         189.9190, 190.7601, 190.6890],
        [190.6273, 190.6546, 190.5832, 190.8088, 190.73

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.5394, 190.7659, 190.7763, 190.7610, 190.9439, 191.6239, 190.7673,
         190.9082, 191.0010, 191.0596],
        [191.6220, 190.8473, 190.8577, 190.8470, 191.0268, 191.7069, 190.8517,
         190.9902, 191.0857, 191.1442],
        [191.2516, 190.4760, 190.4875, 190.4770, 190.6574, 191.3375, 190.4832,
         190.6201, 190.7165, 190.7727],
        [191.2867, 190.5125, 190.5236, 190.5095, 190.6919, 191.3721, 190.5165,
         190.6555, 190.7497, 190.8064],
        [191.5022, 190.7283, 190.7388, 190.7251, 190.9070, 191.5866, 190.7304,
         190.8705, 190.9645, 191.0230],
        [191.3782, 190.6017, 190.6133, 190.6008, 190.7838, 191.4615, 190.6056,
         190.7459, 190.8397, 190.8997],
        [191.2793, 190.5036, 190.5151, 190.5041, 190.6850, 191.3648, 190.5102,
         190.6477, 190.7435, 190.8002],
        [191.3261, 190

       grad_fn=<CdistBackward0>)
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[190.3915, 191.2371, 191.5091, 189.9858, 190.3626, 190.7902, 191.3598,
         191.2589, 190.8262, 191.1200],
        [190.2621, 191.1075, 191.3794, 189.8558, 190.2321, 190.6602, 191.2298,
         191.1272, 190.6960, 190.9898],
        [190.3256, 191.1727, 191.4408, 189.9205, 190.2950, 190.7245, 191.2931,
         191.1903, 190.7585, 191.0546],
        [190.0855, 190.9328, 191.2020, 189.6808, 190.0544, 190.4856, 191.0536,
         190.9521, 190.5194, 190.8152],
        [190.3336, 191.1816, 191.4496, 189.9300, 190.3047, 190.7343, 191.3022,
         191.2016, 190.7685, 191.0648],
        [190.5235, 191.3696, 191.6396, 190.1178, 190.4950, 190.9215, 191.4911,
         191.3879, 190.9574, 191.2520],
        [190.4098, 191.2557, 191.5268, 190.0042, 190.3806, 190.8084, 191.3779,
         191.2768, 190.8440, 191.1383],
        [190.2416, 191

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[190.3375, 190.2080, 190.6040, 190.2982, 190.5834, 190.2997, 190.1826,
         190.1100, 191.2265, 189.7040],
        [190.6473, 190.5206, 190.9142, 190.6088, 190.8924, 190.6110, 190.4931,
         190.4215, 191.5369, 190.0151],
        [190.5737, 190.4424, 190.8391, 190.5307, 190.8184, 190.5346, 190.4188,
         190.3433, 191.4606, 189.9377],
        [190.7408, 190.6127, 191.0078, 190.6995, 190.9853, 190.7034, 190.5863,
         190.5129, 191.6291, 190.1076],
        [190.5267, 190.3978, 190.7926, 190.4865, 190.7720, 190.4891, 190.3725,
         190.2989, 191.4150, 189.8926],
        [190.4519, 190.3228, 190.7182, 190.4128, 190.6972, 190.4145, 190.2968,
         190.2247, 191.3411, 189.8180],
        [190.7773, 190.6489, 191.0438, 190.7352, 191.0218, 190.7397, 190.6232,
         190.5487, 191.6649, 190.1435],
        [190.5930, 190

tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[190.6778, 190.9064, 190.8279, 191.6070, 190.6626, 191.5217, 190.8057,
         190.2826, 190.4909, 190.0871],
        [190.7220, 190.9503, 190.8715, 191.6510, 190.7071, 191.5660, 190.8497,
         190.3264, 190.5352, 190.1310],
        [190.5072, 190.7339, 190.6566, 191.4361, 190.4901, 191.3501, 190.6339,
         190.1115, 190.3193, 189.9152],
        [190.4622, 190.6901, 190.6101, 191.3898, 190.4473, 191.3064, 190.5888,
         190.0657, 190.2750, 189.8669],
        [190.4146, 190.6423, 190.5646, 191.3438, 190.3982, 191.2578, 190.5415,
         190.0195, 190.2268, 189.8228],
        [190.6243, 190.8503, 190.7727, 191.5522, 190.6072, 191.4675, 190.7510,
         190.2277, 190.4368, 190.0307],
        [190.4400, 190.6693, 190.5891, 191.3680, 190.4254, 191.2846, 190.5672,
         190.0445, 190.2529, 189.8458],
        [190.5496, 190.7803, 190.7000, 191.4787, 190.53

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.8042, 190.6524, 190.4287, 190.3114, 190.4646, 190.4980, 190.6466,
         190.5668, 190.7106, 190.7928],
        [189.5695, 190.4179, 190.1956, 190.0802, 190.2343, 190.2670, 190.4133,
         190.3327, 190.4802, 190.5579],
        [189.7229, 190.5706, 190.3489, 190.2315, 190.3862, 190.4196, 190.5650,
         190.4839, 190.6314, 190.7094],
        [190.0480, 190.8953, 190.6740, 190.5579, 190.7115, 190.7444, 190.8902,
         190.8091, 190.9559, 191.0362],
        [189.6081, 190.4562, 190.2341, 190.1174, 190.2719, 190.3051, 190.4510,
         190.3702, 190.5176, 190.5953],
        [189.6970, 190.5449, 190.3234, 190.2082, 190.3625, 190.3951, 190.5406,
         190.4593, 190.6078, 190.6851],
        [189.7410, 190.5892, 190.3660, 190.2497, 190.4030, 190.4359, 190.5840,
         190.5039, 190.6489, 190.7300],
        [189.8039, 190

       grad_fn=<CdistBackward0>)
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[190.2455, 190.2758, 190.8636, 191.2047, 190.0483, 190.4245, 190.4235,
         190.2079, 190.3546, 190.2233],
        [190.6773, 190.7063, 191.2951, 191.6375, 190.4801, 190.8562, 190.8547,
         190.6399, 190.7859, 190.6551],
        [190.3702, 190.3993, 190.9863, 191.3283, 190.1726, 190.5480, 190.5468,
         190.3322, 190.4783, 190.3466],
        [190.3578, 190.3868, 190.9738, 191.3159, 190.1599, 190.5359, 190.5344,
         190.3194, 190.4662, 190.3343],
        [190.1381, 190.1692, 190.7575, 191.0983, 189.9410, 190.3180, 190.3171,
         190.1006, 190.2479, 190.1169],
        [190.3766, 190.4077, 190.9973, 191.3384, 190.1799, 190.5574, 190.5562,
         190.3398, 190.4868, 190.3565],
        [190.3726, 190.4045, 190.9943, 191.3351, 190.1762, 190.5539, 190.5529,
         190.3362, 190.4832, 190.3533],
        [190.3598, 190

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[190.9119, 190.5975, 190.5986, 190.5857, 190.2592, 190.4415, 190.3896,
         190.4131, 190.5113, 189.7797],
        [190.6655, 190.3490, 190.3526, 190.3363, 190.0112, 190.1936, 190.1419,
         190.1655, 190.2640, 189.5324],
        [191.1506, 190.8345, 190.8373, 190.8220, 190.4964, 190.6798, 190.6286,
         190.6502, 190.7485, 190.0178],
        [190.8390, 190.5227, 190.5262, 190.5106, 190.1866, 190.3672, 190.3154,
         190.3415, 190.4372, 189.7061],
        [191.1423, 190.8262, 190.8293, 190.8136, 190.4888, 190.6710, 190.6198,
         190.6425, 190.7403, 190.0095],
        [191.1353, 190.8187, 190.8222, 190.8059, 190.4810, 190.6638, 190.6128,
         190.6347, 190.7330, 190.0023],
        [190.7928, 190.4784, 190.4795, 190.4672, 190.1403, 190.3233, 190.2709,
         190.2947, 190.3921, 189.6606],
        [190.7028, 190

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.3300, 190.5757, 190.4909, 190.6978, 190.3720, 190.3670, 190.3956,
         189.8075, 190.7006, 190.3982],
        [191.2887, 190.5361, 190.4499, 190.6582, 190.3328, 190.3285, 190.3568,
         189.7676, 190.6583, 190.3569],
        [191.0962, 190.3422, 190.2574, 190.4641, 190.1390, 190.1342, 190.1613,
         189.5745, 190.4673, 190.1642],
        [191.3661, 190.6126, 190.5274, 190.7345, 190.4090, 190.4040, 190.4329,
         189.8443, 190.7365, 190.4345],
        [191.4199, 190.6679, 190.5823, 190.7890, 190.4653, 190.4596, 190.4880,
         189.8994, 190.7902, 190.4886],
        [191.2586, 190.5074, 190.4215, 190.6279, 190.3048, 190.2998, 190.3269,
         189.7391, 190.6294, 190.3273],
        [191.2745, 190.5229, 190.4371, 190.6436, 190.3205, 190.3152, 190.3424,
         189.7547, 190.6453, 190.3431],
        [191.1987, 190

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[190.5302, 190.6772, 190.2660, 190.5103, 190.3214, 190.2878, 191.0487,
         190.4101, 190.5903, 189.7791],
        [190.5053, 190.6545, 190.2414, 190.4852, 190.2975, 190.2659, 191.0257,
         190.3894, 190.5665, 189.7542],
        [190.3789, 190.5251, 190.1134, 190.3591, 190.1696, 190.1395, 190.8946,
         190.2574, 190.4387, 189.6277],
        [190.4673, 190.6163, 190.2034, 190.4464, 190.2594, 190.2305, 190.9862,
         190.3505, 190.5284, 189.7160],
        [190.3232, 190.4702, 190.0577, 190.3033, 190.1141, 190.0850, 190.8395,
         190.2034, 190.3834, 189.5719],
        [190.5140, 190.6611, 190.2501, 190.4943, 190.3051, 190.2713, 191.0327,
         190.3940, 190.5742, 189.7631],
        [190.3755, 190.5244, 190.1115, 190.3554, 190.1670, 190.1351, 190.8955,
         190.2592, 190.4365, 189.6242],
        [190.2719, 190

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[190.4862, 190.1406, 190.1112, 190.2306, 190.0462, 190.4309, 190.5816,
         190.4805, 190.3658, 189.6764],
        [190.1500, 189.8027, 189.7730, 189.8916, 189.7075, 190.0924, 190.2454,
         190.1415, 190.0270, 189.3375],
        [190.5111, 190.1642, 190.1323, 190.2518, 190.0702, 190.4509, 190.6060,
         190.5032, 190.3879, 189.6978],
        [190.4003, 190.0528, 190.0212, 190.1409, 189.9586, 190.3403, 190.4952,
         190.3923, 190.2773, 189.5866],
        [190.4584, 190.1109, 190.0787, 190.1980, 190.0168, 190.3970, 190.5531,
         190.4495, 190.3343, 189.6441],
        [190.4339, 190.0868, 190.0572, 190.1762, 189.9916, 190.3760, 190.5285,
         190.4258, 190.3117, 189.6227],
        [190.4459, 190.0991, 190.0687, 190.1866, 190.0040, 190.3863, 190.5407,
         190.4364, 190.3217, 189.6339],
        [190.3679, 190

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[190.2629, 190.5759, 190.4164, 190.3011, 190.5605, 191.0075, 190.3539,
         190.2835, 190.2353, 190.3890],
        [190.3993, 190.7132, 190.5545, 190.4380, 190.6970, 191.1434, 190.4903,
         190.4199, 190.3728, 190.5264],
        [190.0466, 190.3615, 190.2025, 190.0855, 190.3461, 190.7920, 190.1389,
         190.0683, 190.0214, 190.1759],
        [190.0018, 190.3168, 190.1573, 190.0404, 190.3015, 190.7480, 190.0933,
         190.0234, 189.9754, 190.1294],
        [190.3365, 190.6503, 190.4912, 190.3752, 190.6343, 191.0810, 190.4274,
         190.3573, 190.3090, 190.4628],
        [190.0742, 190.3887, 190.2300, 190.1132, 190.3734, 190.8193, 190.1669,
         190.0959, 190.0493, 190.2040],
        [190.0654, 190.3791, 190.2198, 190.1039, 190.3641, 190.8105, 190.1575,
         190.0867, 190.0392, 190.1938],
        [190.1850, 190

tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[190.1352, 190.1558, 190.2200, 189.8707, 190.1358, 190.2257, 190.3626,
         189.9804, 190.2414, 190.1931],
        [190.1103, 190.1309, 190.1929, 189.8462, 190.1114, 190.2003, 190.3373,
         189.9572, 190.2154, 190.1670],
        [190.0920, 190.1124, 190.1756, 189.8275, 190.0928, 190.1821, 190.3192,
         189.9376, 190.1974, 190.1492],
        [190.2005, 190.2188, 190.2830, 189.9337, 190.1996, 190.2872, 190.4268,
         190.0437, 190.3035, 190.2543],
        [190.0271, 190.0460, 190.1114, 189.7614, 190.0270, 190.1154, 190.2546,
         189.8709, 190.1316, 190.0828],
        [190.0857, 190.1054, 190.1702, 189.8199, 190.0856, 190.1757, 190.3128,
         189.9283, 190.1912, 190.1433],
        [190.3808, 190.4008, 190.4643, 190.1161, 190.3811, 190.4692, 190.6074,
         190.2277, 190.4856, 190.4361],
        [190.1779, 190.1989, 190.2624, 189.9135, 190.17

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.6980, 191.1113, 190.1681, 190.4834, 189.5630, 190.3158, 190.3951,
         190.2549, 191.0248, 189.5074],
        [189.6842, 191.0966, 190.1566, 190.4694, 189.5508, 190.3035, 190.3830,
         190.2408, 191.0129, 189.4934],
        [189.6378, 191.0505, 190.1074, 190.4228, 189.5021, 190.2558, 190.3348,
         190.1950, 190.9642, 189.4466],
        [189.3555, 190.7703, 189.8289, 190.1429, 189.2225, 189.9765, 190.0564,
         189.9137, 190.6858, 189.1671],
        [189.6951, 191.1052, 190.1645, 190.4777, 189.5592, 190.3131, 190.3916,
         190.2519, 191.0217, 189.5011],
        [189.6748, 191.0873, 190.1473, 190.4601, 189.5414, 190.2942, 190.3737,
         190.2313, 191.0036, 189.4842],
        [189.8849, 191.2950, 190.3530, 190.6672, 189.7484, 190.5012, 190.5796,
         190.4409, 191.2100, 189.6908],
        [189.9306, 191

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[191.1495, 189.7170, 190.3971, 190.1227, 190.7424, 191.2344, 190.2177,
         190.5343, 190.6142, 190.5569],
        [190.7490, 189.3171, 189.9973, 189.7213, 190.3436, 190.8336, 189.8178,
         190.1356, 190.2149, 190.1581],
        [191.1081, 189.6768, 190.3574, 190.0807, 190.7022, 191.1933, 190.1788,
         190.4953, 190.5747, 190.5160],
        [191.0536, 189.6217, 190.3042, 190.0268, 190.6502, 191.1403, 190.1244,
         190.4401, 190.5213, 190.4611],
        [190.7167, 189.2842, 189.9642, 189.6891, 190.3102, 190.8015, 189.7844,
         190.1022, 190.1819, 190.1257],
        [190.9885, 189.5567, 190.2358, 189.9611, 190.5819, 191.0722, 190.0565,
         190.3743, 190.4531, 190.3970],
        [191.0940, 189.6622, 190.3443, 190.0671, 190.6906, 191.1804, 190.1645,
         190.4803, 190.5613, 190.5014],
        [191.2477, 189

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[190.2403, 190.2346, 190.5777, 190.5796, 190.3710, 189.7330, 190.8924,
         191.3508, 190.4917, 190.2861],
        [190.0800, 190.0749, 190.4162, 190.4182, 190.2092, 189.5715, 190.7309,
         191.1895, 190.3303, 190.1247],
        [189.7631, 189.7587, 190.1009, 190.1033, 189.8935, 189.2551, 190.4160,
         190.8738, 190.0159, 189.8088],
        [190.2478, 190.2424, 190.5859, 190.5876, 190.3798, 189.7413, 190.9007,
         191.3589, 190.5000, 190.2943],
        [189.8922, 189.8867, 190.2304, 190.2303, 190.0226, 189.3852, 190.5434,
         191.0022, 190.1441, 189.9371],
        [189.9941, 189.9894, 190.3351, 190.3366, 190.1297, 189.4896, 190.6500,
         191.1070, 190.2497, 190.0425],
        [190.0111, 190.0051, 190.3494, 190.3493, 190.1412, 189.5043, 190.6624,
         191.1213, 190.2632, 190.0561],
        [189.9120, 189

       grad_fn=<CdistBackward0>)
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[190.2081, 190.3613, 190.2315, 189.7325, 190.3257, 190.6046, 190.3621,
         190.6271, 189.9069, 189.7010],
        [189.7683, 189.9224, 189.7930, 189.2911, 189.8864, 190.1670, 189.9233,
         190.1882, 189.4663, 189.2618],
        [189.8364, 189.9900, 189.8591, 189.3581, 189.9532, 190.2344, 189.9882,
         190.2552, 189.5335, 189.3283],
        [189.9297, 190.0840, 189.9539, 189.4546, 190.0475, 190.3280, 190.0849,
         190.3502, 189.6283, 189.4228],
        [189.8824, 190.0366, 189.9075, 189.4063, 190.0012, 190.2808, 190.0383,
         190.3023, 189.5811, 189.3765],
        [189.9581, 190.1115, 189.9818, 189.4803, 190.0758, 190.3556, 190.1114,
         190.3766, 189.6560, 189.4512],
        [189.8875, 190.0407, 189.9104, 189.4097, 190.0040, 190.2852, 190.0403,
         190.3069, 189.5851, 189.3796],
        [190.0040, 190

tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[190.0466, 189.4295, 189.7089, 190.0402, 189.7554, 189.9031, 189.3867,
         189.8626, 190.0132, 190.7465],
        [190.3601, 189.7424, 190.0224, 190.3537, 190.0694, 190.2154, 189.6999,
         190.1752, 190.3254, 191.0589],
        [190.5233, 189.9053, 190.1859, 190.5165, 190.2325, 190.3784, 189.8641,
         190.3392, 190.4884, 191.2230],
        [190.2146, 189.5965, 189.8744, 190.2078, 189.9230, 190.0709, 189.5544,
         190.0270, 190.1777, 190.9126],
        [190.3184, 189.7006, 189.9796, 190.3117, 190.0270, 190.1749, 189.6598,
         190.1332, 190.2827, 191.0182],
        [190.1458, 189.5291, 189.8099, 190.1396, 189.8551, 190.0018, 189.4863,
         189.9639, 190.1140, 190.8466],
        [190.1322, 189.5140, 189.7936, 190.1251, 189.8407, 189.9876, 189.4712,
         189.9465, 190.0975, 190.8307],
        [190.0850, 189.4682, 189.7462, 190.0787, 189.79

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[190.3816, 190.1170, 189.9199, 190.7522, 191.0714, 190.2060, 190.0800,
         190.1429, 190.7298, 189.9942],
        [190.1470, 189.8813, 189.6852, 190.5156, 190.8372, 189.9734, 189.8455,
         189.9079, 190.4974, 189.7629],
        [190.0404, 189.7740, 189.5780, 190.4089, 190.7308, 189.8674, 189.7382,
         189.8017, 190.3914, 189.6558],
        [190.2516, 189.9869, 189.7896, 190.6221, 190.9414, 190.0761, 189.9497,
         190.0130, 190.6000, 189.8636],
        [190.2587, 189.9933, 189.7964, 190.6282, 190.9480, 190.0838, 189.9566,
         190.0198, 190.6073, 189.8730],
        [190.3238, 190.0574, 189.8619, 190.6940, 191.0158, 190.1508, 190.0225,
         190.0862, 190.6757, 189.9371],
        [190.2611, 189.9955, 189.7997, 190.6294, 190.9512, 190.0874, 189.9602,
         190.0215, 190.6113, 189.8782],
        [190.2432, 189.9776, 189.7813, 190.6123, 190.93

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.8970, 190.1223, 190.2117, 189.3439, 189.5718, 190.3275, 190.1821,
         190.9606, 190.0449, 190.0382],
        [190.0124, 190.2380, 190.3262, 189.4592, 189.6859, 190.4443, 190.2970,
         191.0761, 190.1591, 190.1507],
        [190.1083, 190.3333, 190.4227, 189.5548, 189.7826, 190.5389, 190.3935,
         191.1712, 190.2568, 190.2466],
        [189.9502, 190.1762, 190.2642, 189.3977, 189.6249, 190.3820, 190.2359,
         191.0135, 190.0999, 190.0876],
        [190.0309, 190.2564, 190.3447, 189.4774, 189.7044, 190.4625, 190.3154,
         191.0945, 190.1773, 190.1692],
        [189.9702, 190.1960, 190.2839, 189.4173, 189.6440, 190.4022, 190.2550,
         191.0340, 190.1176, 190.1086],
        [189.8842, 190.1103, 190.1979, 189.3312, 189.5584, 190.3161, 190.1696,
         190.9479, 190.0328, 190.0217],
        [189.8655, 190

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.9507, 189.8787, 189.7340, 190.3091, 190.1917, 189.8112, 189.9585,
         189.9380, 189.5470, 189.5105],
        [190.5372, 190.4651, 190.3196, 190.8951, 190.7783, 190.3973, 190.5452,
         190.5224, 190.1340, 190.0992],
        [190.1528, 190.0804, 189.9343, 190.5094, 190.3929, 190.0124, 190.1613,
         190.1404, 189.7500, 189.7126],
        [190.2148, 190.1410, 189.9980, 190.5713, 190.4541, 190.0737, 190.2224,
         190.2010, 189.8101, 189.7755],
        [190.3007, 190.2270, 190.0830, 190.6580, 190.5402, 190.1599, 190.3078,
         190.2867, 189.8954, 189.8614],
        [190.3471, 190.2766, 190.1297, 190.7063, 190.5896, 190.2085, 190.3556,
         190.3332, 189.9454, 189.9086],
        [190.3250, 190.2517, 190.1063, 190.6812, 190.5644, 190.1839, 190.3331,
         190.3118, 189.9210, 189.8855],
        [190.2798, 190

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[190.0761, 190.2271, 189.3537, 190.0153, 189.8769, 190.0134, 190.1928,
         189.9444, 190.0932, 190.4832],
        [190.0866, 190.2394, 189.3639, 190.0285, 189.8875, 190.0251, 190.2029,
         189.9568, 190.1048, 190.4928],
        [189.8463, 189.9992, 189.1236, 189.7867, 189.6469, 189.7834, 189.9633,
         189.7148, 189.8639, 190.2533],
        [189.9871, 190.1408, 189.2643, 189.9291, 189.7882, 189.9249, 190.1031,
         189.8556, 190.0056, 190.3940],
        [190.0404, 190.1937, 189.3175, 189.9824, 189.8415, 189.9789, 190.1560,
         189.9089, 190.0593, 190.4471],
        [190.3293, 190.4804, 189.6068, 190.2697, 190.1308, 190.2677, 190.4441,
         190.1968, 190.3476, 190.7364],
        [190.2095, 190.3595, 189.4864, 190.1488, 190.0102, 190.1490, 190.3247,
         190.0778, 190.2277, 190.6160],
        [190.3010, 190

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.8960, 189.1541, 189.2377, 189.9182, 189.8472, 189.9865, 190.0469,
         190.1275, 189.9110, 189.9005],
        [190.0187, 189.2739, 189.3598, 190.0386, 189.9693, 190.1087, 190.1682,
         190.2473, 190.0314, 190.0212],
        [189.7490, 189.0062, 189.0903, 189.7698, 189.6977, 189.8392, 189.8985,
         189.9788, 189.7616, 189.7502],
        [190.1087, 189.3648, 189.4496, 190.1295, 190.0588, 190.1981, 190.2580,
         190.3379, 190.1209, 190.1110],
        [189.8902, 189.1496, 189.2327, 189.9132, 189.8414, 189.9813, 190.0417,
         190.1228, 189.9056, 189.8949],
        [190.0590, 189.3146, 189.4000, 190.0791, 190.0080, 190.1486, 190.2078,
         190.2875, 190.0702, 190.0598],
        [189.9298, 189.1861, 189.2722, 189.9501, 189.8812, 190.0210, 190.0804,
         190.1594, 189.9439, 189.9332],
        [189.6606, 188

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.9057, 189.8401, 189.8769, 190.8318, 190.2198, 190.9898, 189.8898,
         190.2423, 189.5816, 189.5611],
        [189.5045, 189.4388, 189.4760, 190.4335, 189.8207, 190.5896, 189.4899,
         189.8422, 189.1793, 189.1620],
        [189.8902, 189.8244, 189.8614, 190.8159, 190.2042, 190.9740, 189.8744,
         190.2266, 189.5658, 189.5452],
        [189.9590, 189.8932, 189.9286, 190.8873, 190.2738, 191.0436, 189.9430,
         190.2965, 189.6353, 189.6174],
        [190.0591, 189.9925, 190.0288, 190.9853, 190.3735, 191.1432, 190.0431,
         190.3956, 189.7350, 189.7146],
        [189.7231, 189.6572, 189.6943, 190.6496, 190.0375, 190.8067, 189.7077,
         190.0600, 189.3981, 189.3796],
        [189.7172, 189.6506, 189.6873, 190.6459, 190.0335, 190.8024, 189.7023,
         190.0545, 189.3923, 189.3742],
        [189.9656, 189

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.5977, 190.1254, 190.6536, 190.1044, 189.2040, 190.0977, 189.9175,
         189.7923, 190.0898, 190.0111],
        [189.7481, 190.2766, 190.8036, 190.2547, 189.3533, 190.2478, 190.0676,
         189.9414, 190.2393, 190.1615],
        [189.6664, 190.1945, 190.7222, 190.1722, 189.2733, 190.1662, 189.9884,
         189.8601, 190.1574, 190.0788],
        [189.8418, 190.3699, 190.8975, 190.3478, 189.4478, 190.3413, 190.1623,
         190.0363, 190.3326, 190.2545],
        [189.6116, 190.1403, 190.6675, 190.1195, 189.2173, 190.1122, 189.9310,
         189.8052, 190.1039, 190.0257],
        [189.5272, 190.0554, 190.5832, 190.0345, 189.1331, 190.0276, 189.8466,
         189.7211, 190.0195, 189.9412],
        [189.3957, 189.9232, 190.4512, 189.9009, 189.0017, 189.8952, 189.7161,
         189.5883, 189.8868, 189.8086],
        [189.7063, 190

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.7914, 189.4440, 189.5108, 189.6355, 189.7649, 189.8065, 189.2922,
         190.0226, 189.6348, 189.8756],
        [189.8987, 189.5502, 189.6174, 189.7424, 189.8717, 189.9119, 189.3970,
         190.1285, 189.7414, 189.9809],
        [189.8933, 189.5424, 189.6105, 189.7343, 189.8638, 189.9041, 189.3888,
         190.1209, 189.7357, 189.9742],
        [189.7629, 189.4142, 189.4808, 189.6060, 189.7341, 189.7757, 189.2612,
         189.9933, 189.6057, 189.8455],
        [189.8156, 189.4652, 189.5327, 189.6570, 189.7860, 189.8269, 189.3118,
         190.0441, 189.6582, 189.8971],
        [189.7588, 189.4090, 189.4759, 189.6006, 189.7284, 189.7703, 189.2558,
         189.9884, 189.6016, 189.8410],
        [189.7434, 189.3942, 189.4608, 189.5862, 189.7150, 189.7562, 189.2408,
         189.9728, 189.5859, 189.8254],
        [189.5561, 189

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[190.3765, 189.5460, 189.2136, 189.6335, 189.5140, 189.7702, 189.7011,
         189.8512, 189.6117, 189.7743],
        [190.1634, 189.3345, 189.0046, 189.4233, 189.3047, 189.5623, 189.4925,
         189.6415, 189.4003, 189.5621],
        [190.4716, 189.6418, 189.3111, 189.7295, 189.6102, 189.8675, 189.7973,
         189.9476, 189.7076, 189.8699],
        [190.4148, 189.5842, 189.2531, 189.6722, 189.5524, 189.8095, 189.7394,
         189.8900, 189.6504, 189.8123],
        [190.5256, 189.6954, 189.3645, 189.7830, 189.6634, 189.9205, 189.8503,
         190.0009, 189.7613, 189.9236],
        [190.7031, 189.8722, 189.5434, 189.9613, 189.8411, 190.0981, 190.0271,
         190.1782, 189.9393, 190.0995],
        [190.6740, 189.8437, 189.5155, 189.9329, 189.8129, 190.0707, 189.9991,
         190.1503, 189.9108, 190.0710],
        [190.5020, 189

tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[190.0362, 189.8257, 189.9428, 189.5823, 189.6738, 189.5980, 189.5041,
         189.8190, 190.7459, 190.0575],
        [189.9376, 189.7262, 189.8443, 189.4804, 189.5721, 189.4981, 189.4047,
         189.7212, 190.6464, 189.9588],
        [190.0766, 189.8657, 189.9824, 189.6262, 189.7171, 189.6388, 189.5457,
         189.8624, 190.7883, 190.0997],
        [190.0298, 189.8194, 189.9361, 189.5750, 189.6671, 189.5912, 189.4984,
         189.8132, 190.7385, 190.0517],
        [190.0873, 189.8755, 189.9939, 189.6326, 189.7226, 189.6485, 189.5552,
         189.8737, 190.7987, 190.1097],
        [189.7411, 189.5292, 189.6470, 189.2872, 189.3794, 189.3017, 189.2077,
         189.5244, 190.4505, 189.7625],
        [190.0484, 189.8379, 189.9550, 189.5917, 189.6838, 189.6095, 189.5165,
         189.8313, 190.7566, 190.0697],
        [189.9435, 189.7327, 189.8496, 189.4897, 189.58

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.5382, 189.4574, 189.8687, 189.2111, 189.7834, 189.3935, 190.1817,
         189.9405, 189.8624, 189.7769],
        [189.6857, 189.6038, 190.0148, 189.3581, 189.9285, 189.5437, 190.3273,
         190.0887, 190.0094, 189.9223],
        [189.8789, 189.7975, 190.2086, 189.5514, 190.1226, 189.7363, 190.5214,
         190.2814, 190.2025, 190.1162],
        [189.4796, 189.3979, 189.8099, 189.1503, 189.7247, 189.3348, 190.1227,
         189.8816, 189.8036, 189.7172],
        [189.6280, 189.5466, 189.9572, 189.3020, 189.8706, 189.4860, 190.2697,
         190.0313, 189.9520, 189.8654],
        [189.6321, 189.5503, 189.9613, 189.3044, 189.8751, 189.4905, 190.2736,
         190.0355, 189.9561, 189.8686],
        [189.6407, 189.5617, 189.9709, 189.3149, 189.8865, 189.4990, 190.2838,
         190.0453, 189.9658, 189.8792],
        [189.2924, 189

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 2, 3, 3, 2, 3, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[190.2702, 189.5497, 189.6834, 189.4468, 189.6578, 190.5390, 190.0504,
         189.8295, 189.6001, 190.2901],
        [190.2889, 189.5687, 189.7035, 189.4639, 189.6768, 190.5570, 190.0690,
         189.8468, 189.6191, 190.3104],
        [190.2429, 189.5221, 189.6559, 189.4191, 189.6296, 190.5110, 190.0222,
         189.8015, 189.5725, 190.2623],
        [190.3280, 189.6081, 189.7424, 189.5044, 189.7166, 190.5970, 190.1087,
         189.8869, 189.6582, 190.3495],
        [190.1478, 189.4263, 189.5598, 189.3237, 189.5334, 190.4154, 189.9263,
         189.7062, 189.4771, 190.1658],
        [190.3352, 189.6133, 189.7471, 189.5092, 189.7185, 190.6002, 190.1120,
         189.8917, 189.6649, 190.3530],
        [190.0632, 189.3413, 189.4751, 189.2382, 189.4479, 190.3299, 189.8407,
         189.6208, 189.3922, 190.0808],
        [190.2824, 189

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.4728, 189.0937, 188.9888, 189.5116, 189.7292, 189.7178, 189.7114,
         189.7970, 189.3283, 190.5701],
        [189.5667, 189.1867, 189.0811, 189.6049, 189.8233, 189.8095, 189.8035,
         189.8907, 189.4247, 190.6635],
        [189.3152, 188.9327, 188.8303, 189.3533, 189.5706, 189.5599, 189.5525,
         189.6394, 189.1709, 190.4125],
        [189.5098, 189.1286, 189.0240, 189.5477, 189.7657, 189.7526, 189.7463,
         189.8335, 189.3670, 190.6064],
        [189.5978, 189.2194, 189.1126, 189.6364, 189.8552, 189.8407, 189.8354,
         189.9220, 189.4559, 190.6945],
        [189.5626, 189.1837, 189.0769, 189.6009, 189.8200, 189.8049, 189.7998,
         189.8866, 189.4208, 190.6590],
        [189.5442, 189.1637, 189.0603, 189.5831, 189.7997, 189.7898, 189.7821,
         189.8685, 189.4001, 190.6423],
        [189.4997, 189

       grad_fn=<CdistBackward0>)
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.4372, 189.6822, 188.9532, 189.7905, 190.5343, 190.2143, 189.7238,
         189.6686, 189.5265, 189.1290],
        [189.5270, 189.7694, 189.0413, 189.8776, 190.6231, 190.3057, 189.8148,
         189.7582, 189.6149, 189.2199],
        [189.2797, 189.5244, 188.7948, 189.6322, 190.3767, 190.0586, 189.5666,
         189.5096, 189.3688, 188.9712],
        [189.6416, 189.8856, 189.1571, 189.9932, 190.7387, 190.4197, 189.9289,
         189.8734, 189.7307, 189.3348],
        [189.4426, 189.6852, 188.9569, 189.7934, 190.5387, 190.2213, 189.7301,
         189.6732, 189.5305, 189.1349],
        [189.4356, 189.6788, 188.9499, 189.7863, 190.5322, 190.2155, 189.7231,
         189.6656, 189.5241, 189.1283],
        [189.4433, 189.6863, 188.9574, 189.7938, 190.5393, 190.2217, 189.7304,
         189.6722, 189.5310, 189.1341],
        [189.4967, 189

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.7373, 189.8357, 189.7629, 190.5040, 190.2141, 190.0679, 189.8567,
         189.8910, 190.0619, 190.0749],
        [189.3284, 189.4281, 189.3546, 190.0970, 189.8057, 189.6609, 189.4476,
         189.4825, 189.6538, 189.6679],
        [189.5931, 189.6910, 189.6178, 190.3586, 190.0682, 189.9230, 189.7121,
         189.7449, 189.9176, 189.9296],
        [189.4222, 189.5217, 189.4473, 190.1891, 189.8976, 189.7539, 189.5422,
         189.5761, 189.7487, 189.7594],
        [189.5132, 189.6126, 189.5389, 190.2820, 189.9914, 189.8450, 189.6323,
         189.6676, 189.8385, 189.8524],
        [189.4261, 189.5269, 189.4525, 190.1967, 189.9054, 189.7596, 189.5458,
         189.5825, 189.7527, 189.7667],
        [189.4123, 189.5128, 189.4374, 190.1806, 189.8886, 189.7450, 189.5330,
         189.5679, 189.7402, 189.7500],
        [189.4170, 189

tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.7591, 189.5905, 189.5808, 189.5145, 190.3233, 189.8814, 189.6750,
         190.3434, 190.0328, 189.7408],
        [189.6925, 189.5222, 189.5144, 189.4479, 190.2580, 189.8160, 189.6082,
         190.2776, 189.9676, 189.6739],
        [189.5193, 189.3500, 189.3410, 189.2754, 190.0832, 189.6415, 189.4358,
         190.1035, 189.7924, 189.5024],
        [189.3454, 189.1794, 189.1683, 189.1015, 189.9113, 189.4693, 189.2619,
         189.9283, 189.6190, 189.3280],
        [189.5762, 189.4097, 189.3982, 189.3325, 190.1395, 189.6978, 189.4930,
         190.1592, 189.8481, 189.5596],
        [189.8356, 189.6699, 189.6586, 189.5916, 190.4004, 189.9578, 189.7526,
         190.4187, 190.1098, 189.8176],
        [189.5665, 189.4017, 189.3895, 189.3231, 190.1310, 189.6887, 189.4838,
         190.1487, 189.8392, 189.5498],
        [189.5569, 189.3870, 189.3779, 189.3128, 190.11

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.8069, 189.6503, 189.3545, 190.3581, 189.5264, 189.3840, 189.7051,
         188.9088, 189.8868, 190.1965],
        [189.7813, 189.6259, 189.3298, 190.3338, 189.5005, 189.3597, 189.6822,
         188.8840, 189.8641, 190.1727],
        [189.8003, 189.6431, 189.3471, 190.3499, 189.5183, 189.3772, 189.6998,
         188.9007, 189.8787, 190.1908],
        [189.9632, 189.8098, 189.5157, 190.5168, 189.6848, 189.5438, 189.8638,
         189.0675, 190.0456, 190.3537],
        [190.0185, 189.8637, 189.5683, 190.5703, 189.7377, 189.5980, 189.9202,
         189.1214, 190.1005, 190.4106],
        [189.7981, 189.6436, 189.3479, 190.3513, 189.5181, 189.3773, 189.6994,
         188.9013, 189.8812, 190.1894],
        [189.7281, 189.5712, 189.2742, 190.2788, 189.4455, 189.3051, 189.6290,
         188.8287, 189.8089, 190.1197],
        [189.8105, 189

tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.4186, 189.2823, 189.2475, 189.3875, 189.3350, 189.3243, 189.4421,
         189.5575, 189.5820, 189.4671],
        [189.6044, 189.4697, 189.4337, 189.5729, 189.5210, 189.5084, 189.6293,
         189.7437, 189.7704, 189.6553],
        [189.2224, 189.0848, 189.0486, 189.1899, 189.1375, 189.1266, 189.2430,
         189.3612, 189.3850, 189.2705],
        [189.3584, 189.2216, 189.1850, 189.3259, 189.2739, 189.2612, 189.3798,
         189.4977, 189.5227, 189.4081],
        [189.5616, 189.4236, 189.3885, 189.5290, 189.4769, 189.4658, 189.5827,
         189.6998, 189.7246, 189.6098],
        [189.5682, 189.4318, 189.3949, 189.5357, 189.4843, 189.4694, 189.5896,
         189.7080, 189.7337, 189.6194],
        [189.3287, 189.1913, 189.1558, 189.2965, 189.2440, 189.2337, 189.3503,
         189.4672, 189.4914, 189.3766],
        [189.5845, 189.4477, 189.4113, 189.5519, 189.50

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.4439, 190.4426, 189.5300, 189.8210, 189.0403, 189.0915, 189.8886,
         189.5952, 189.4062, 189.2768],
        [189.3918, 190.3894, 189.4769, 189.7676, 188.9863, 189.0386, 189.8350,
         189.5415, 189.3513, 189.2225],
        [189.3236, 190.3192, 189.4056, 189.6965, 188.9154, 188.9693, 189.7603,
         189.4711, 189.2795, 189.1517],
        [189.0656, 190.0639, 189.1503, 189.4419, 188.6583, 188.7123, 189.5070,
         189.2164, 189.0264, 188.8962],
        [189.4062, 190.4039, 189.4907, 189.7815, 188.9998, 189.0530, 189.8455,
         189.5569, 189.3660, 189.2379],
        [189.2514, 190.2497, 189.3369, 189.6280, 188.8456, 188.8984, 189.6946,
         189.4023, 189.2126, 189.0829],
        [189.4194, 190.4173, 189.5048, 189.7956, 189.0146, 189.0665, 189.8630,
         189.5696, 189.3798, 189.2509],
        [189.2937, 190

tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.4810, 189.8124, 189.8634, 188.9517, 189.6395, 189.7911, 189.6487,
         189.5686, 190.2196, 190.2368],
        [189.4803, 189.8120, 189.8630, 188.9510, 189.6392, 189.7903, 189.6482,
         189.5674, 190.2194, 190.2377],
        [189.0287, 189.3641, 189.4134, 188.5002, 189.1872, 189.3408, 189.1981,
         189.1182, 189.7699, 189.7884],
        [189.3649, 189.6983, 189.7472, 188.8357, 189.5242, 189.6757, 189.5338,
         189.4525, 190.1054, 190.1210],
        [189.6005, 189.9311, 189.9830, 189.0709, 189.7576, 189.9101, 189.7664,
         189.6861, 190.3390, 190.3597],
        [189.1305, 189.4650, 189.5152, 188.6018, 189.2890, 189.4420, 189.2993,
         189.2192, 189.8712, 189.8910],
        [189.2301, 189.5632, 189.6151, 188.7014, 189.3879, 189.5411, 189.3978,
         189.3182, 189.9697, 189.9921],
        [189.5719, 189.9027, 189.9547, 189.0423, 189.73

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.5580, 189.2812, 189.6394, 189.7814, 189.3066, 189.4606, 189.4940,
         189.4065, 189.3733, 189.3540],
        [189.5468, 189.2672, 189.6260, 189.7704, 189.2920, 189.4480, 189.4795,
         189.3930, 189.3600, 189.3409],
        [189.4972, 189.2172, 189.5763, 189.7206, 189.2426, 189.3977, 189.4292,
         189.3432, 189.3101, 189.2911],
        [189.5798, 189.2997, 189.6588, 189.8035, 189.3259, 189.4798, 189.5108,
         189.4257, 189.3929, 189.3738],
        [189.8015, 189.5236, 189.8816, 190.0260, 189.5499, 189.7031, 189.7341,
         189.6486, 189.6166, 189.5968],
        [189.4825, 189.2025, 189.5622, 189.7060, 189.2297, 189.3817, 189.4136,
         189.3285, 189.2957, 189.2766],
        [189.7941, 189.5161, 189.8741, 190.0185, 189.5419, 189.6957, 189.7269,
         189.6411, 189.6091, 189.5892],
        [189.3947, 189

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.5065, 189.6999, 189.8310, 189.0307, 189.8679, 190.2720, 189.9238,
         189.6203, 190.6090, 189.6038],
        [189.1314, 189.3241, 189.4556, 188.6542, 189.4918, 189.8955, 189.5486,
         189.2446, 190.2325, 189.2273],
        [189.4504, 189.6477, 189.7767, 188.9765, 189.8138, 190.2182, 189.8689,
         189.5671, 190.5541, 189.5490],
        [189.0167, 189.2112, 189.3427, 188.5413, 189.3776, 189.7832, 189.4340,
         189.1315, 190.1193, 189.1142],
        [189.4186, 189.6163, 189.7458, 188.9453, 189.7827, 190.1869, 189.8378,
         189.5347, 190.5228, 189.5174],
        [189.1919, 189.3879, 189.5169, 188.7159, 189.5538, 189.9576, 189.6099,
         189.3081, 190.2939, 189.2887],
        [189.3032, 189.4960, 189.6269, 188.8259, 189.6637, 190.0672, 189.7202,
         189.4169, 190.4044, 189.3992],
        [189.4695, 189.6651, 189.7951, 188.9947, 189.83

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.4865, 189.8837, 189.0701, 188.7340, 189.5715, 189.3890, 188.8663,
         189.2729, 189.6308, 189.2350],
        [189.5026, 189.8980, 189.0844, 188.7497, 189.5855, 189.4029, 188.8790,
         189.2858, 189.6451, 189.2503],
        [189.4967, 189.8923, 189.0766, 188.7425, 189.5798, 189.3960, 188.8717,
         189.2769, 189.6370, 189.2424],
        [189.4884, 189.8849, 189.0693, 188.7343, 189.5722, 189.3904, 188.8663,
         189.2717, 189.6277, 189.2325],
        [189.7381, 190.1350, 189.3202, 188.9855, 189.8226, 189.6402, 189.1171,
         189.5213, 189.8798, 189.4857],
        [189.4109, 189.8078, 188.9913, 188.6563, 189.4957, 189.3100, 188.7865,
         189.1908, 189.5532, 189.1580],
        [189.4911, 189.8878, 189.0738, 188.7383, 189.5757, 189.3923, 188.8694,
         189.2754, 189.6352, 189.2399],
        [189.5338, 189

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.7219, 189.2639, 188.8966, 189.3902, 189.3011, 190.2805, 189.4229,
         189.6918, 189.6488, 189.4644],
        [189.5381, 189.0815, 188.7119, 189.2091, 189.1156, 190.0966, 189.2383,
         189.5089, 189.4644, 189.2796],
        [189.5345, 189.0734, 188.7079, 189.2010, 189.1114, 190.0920, 189.2342,
         189.5025, 189.4618, 189.2753],
        [189.6159, 189.1591, 188.7898, 189.2865, 189.1938, 190.1743, 189.3162,
         189.5868, 189.5424, 189.3577],
        [189.5349, 189.0745, 188.7070, 189.2033, 189.1109, 190.0917, 189.2334,
         189.5049, 189.4617, 189.2752],
        [189.6096, 189.1522, 188.7837, 189.2796, 189.1874, 190.1680, 189.3100,
         189.5800, 189.5364, 189.3514],
        [189.4060, 188.9472, 188.5793, 189.0755, 188.9819, 189.9637, 189.1057,
         189.3753, 189.3329, 189.1465],
        [189.4500, 188

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.4400, 189.3043, 189.9534, 189.0239, 189.6068, 189.5512, 189.2368,
         190.2207, 189.3851, 189.5997],
        [189.6471, 189.5094, 190.1604, 189.2299, 189.8124, 189.7572, 189.4437,
         190.4257, 189.5901, 189.8051],
        [189.3689, 189.2331, 189.8820, 188.9521, 189.5347, 189.4790, 189.1648,
         190.1490, 189.3131, 189.5279],
        [189.4503, 189.3113, 189.9629, 189.0305, 189.6126, 189.5573, 189.2444,
         190.2261, 189.3908, 189.6059],
        [189.4789, 189.3419, 189.9916, 189.0616, 189.6451, 189.5889, 189.2750,
         190.2585, 189.4211, 189.6378],
        [189.5975, 189.4596, 190.1112, 189.1801, 189.7625, 189.7077, 189.3940,
         190.3758, 189.5414, 189.7551],
        [189.3877, 189.2502, 189.9005, 188.9700, 189.5536, 189.4975, 189.1836,
         190.1667, 189.3303, 189.5461],
        [189.4342, 189

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.0336, 189.4163, 189.9513, 188.9994, 189.3496, 188.8693, 189.4968,
         189.6391, 188.6115, 190.0477],
        [189.0291, 189.4117, 189.9455, 188.9957, 189.3479, 188.8647, 189.4926,
         189.6344, 188.6065, 190.0436],
        [188.9371, 189.3213, 189.8554, 188.9034, 189.2535, 188.7731, 189.4012,
         189.5438, 188.5156, 189.9519],
        [189.0338, 189.4220, 189.9522, 189.0023, 189.3526, 188.8716, 189.4998,
         189.6447, 188.6132, 190.0508],
        [189.2631, 189.6477, 190.1812, 189.2297, 189.5807, 189.1001, 189.7271,
         189.8715, 188.8406, 190.2782],
        [188.9402, 189.3265, 189.8571, 188.9086, 189.2610, 188.7768, 189.4057,
         189.5490, 188.5187, 189.9567],
        [188.9632, 189.3509, 189.8816, 188.9316, 189.2831, 188.8005, 189.4294,
         189.5737, 188.5421, 189.9802],
        [189.0461, 189

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.0376, 189.5438, 189.1621, 189.2872, 188.8385, 189.0932, 190.1634,
         189.1295, 189.7845, 189.9005],
        [188.7013, 189.2104, 188.8259, 188.9519, 188.5031, 188.7596, 189.8305,
         188.7953, 189.4491, 189.5662],
        [189.1363, 189.6433, 189.2598, 189.3852, 188.9379, 189.1933, 190.2636,
         189.2282, 189.8806, 189.9971],
        [188.9601, 189.4672, 189.0848, 189.2097, 188.7614, 189.0166, 190.0869,
         189.0528, 189.7069, 189.8230],
        [188.9515, 189.4582, 189.0743, 189.2028, 188.7517, 189.0072, 190.0786,
         189.0420, 189.6970, 189.8141],
        [188.7355, 189.2439, 188.8593, 188.9870, 188.5363, 188.7927, 189.8643,
         188.8282, 189.4826, 189.5998],
        [189.0291, 189.5359, 189.1522, 189.2798, 188.8303, 189.0853, 190.1558,
         189.1199, 189.7755, 189.8928],
        [189.1205, 189

tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.3215, 189.2957, 189.2280, 189.4425, 189.2091, 189.0360, 188.9525,
         188.7511, 190.1265, 188.8922],
        [189.4674, 189.4420, 189.3740, 189.5883, 189.3547, 189.1819, 189.0991,
         188.8968, 190.2729, 189.0383],
        [189.5174, 189.4913, 189.4243, 189.6384, 189.4052, 189.2315, 189.1489,
         188.9470, 190.3224, 189.0891],
        [189.4443, 189.4179, 189.3525, 189.5664, 189.3349, 189.1581, 189.0765,
         188.8743, 190.2494, 189.0175],
        [189.2999, 189.2739, 189.2076, 189.4218, 189.1897, 189.0141, 188.9317,
         188.7299, 190.1051, 188.8720],
        [189.3152, 189.2916, 189.2225, 189.4371, 189.2036, 189.0302, 188.9496,
         188.7472, 190.1217, 188.8868],
        [189.4638, 189.4371, 189.3717, 189.5856, 189.3539, 189.1775, 189.0957,
         188.8938, 190.2684, 189.0368],
        [189.4860, 189.4601, 189.3925, 189.6067, 189.37

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.7045, 189.1010, 188.9602, 188.4468, 189.7841, 189.4109, 188.9043,
         189.4176, 189.1806, 189.4761],
        [188.6263, 189.0224, 188.8819, 188.3689, 189.7074, 189.3326, 188.8253,
         189.3399, 189.1035, 189.3981],
        [188.4892, 188.8868, 188.7459, 188.2328, 189.5696, 189.1966, 188.6884,
         189.2034, 188.9661, 189.2612],
        [188.8123, 189.2078, 189.0669, 188.5538, 189.8931, 189.5177, 189.0123,
         189.5252, 189.2897, 189.5847],
        [188.7008, 189.0959, 188.9570, 188.4437, 189.7823, 189.4068, 188.8988,
         189.4141, 189.1770, 189.4710],
        [188.4560, 188.8532, 188.7127, 188.1999, 189.5373, 189.1632, 188.6542,
         189.1701, 188.9333, 189.2282],
        [188.7914, 189.1866, 189.0462, 188.5333, 189.8732, 189.4965, 188.9904,
         189.5045, 189.2693, 189.5638],
        [188.8310, 189

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.1827, 189.6302, 189.8029, 189.1163, 189.2031, 189.7008, 189.4011,
         189.2702, 188.6029, 189.1562],
        [189.1037, 189.5534, 189.7260, 189.0360, 189.1249, 189.6200, 189.3238,
         189.1915, 188.5233, 189.0763],
        [189.0980, 189.5474, 189.7196, 189.0330, 189.1201, 189.6176, 189.3183,
         189.1886, 188.5196, 189.0726],
        [188.7822, 189.2309, 189.4049, 188.7158, 188.8041, 189.2996, 189.0038,
         188.8729, 188.2030, 188.7550],
        [189.1217, 189.5706, 189.7433, 189.0554, 189.1432, 189.6393, 189.3419,
         189.2112, 188.5420, 189.0946],
        [188.8038, 189.2525, 189.4262, 188.7375, 188.8255, 189.3216, 189.0251,
         188.8944, 188.2246, 188.7768],
        [189.0659, 189.5162, 189.6885, 188.9990, 189.0877, 189.5831, 189.2865,
         189.1548, 188.4861, 189.0392],
        [189.2914, 189

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.5229, 189.9897, 189.3171, 188.9751, 189.1671, 189.2001, 189.0945,
         189.3782, 189.8308, 189.0955],
        [189.6542, 190.1226, 189.4488, 189.1087, 189.3006, 189.3351, 189.2270,
         189.5113, 189.9614, 189.2277],
        [189.5646, 190.0307, 189.3587, 189.0163, 189.2088, 189.2414, 189.1361,
         189.4198, 189.8734, 189.1373],
        [189.5589, 190.0268, 189.3539, 189.0129, 189.2066, 189.2374, 189.1322,
         189.4160, 189.8702, 189.1331],
        [189.4205, 189.8901, 189.2160, 188.8762, 189.0683, 189.1009, 188.9942,
         189.2786, 189.7290, 188.9946],
        [189.4253, 189.8921, 189.2201, 188.8775, 189.0702, 189.1020, 188.9974,
         189.2812, 189.7354, 188.9986],
        [189.5401, 190.0082, 189.3347, 188.9942, 189.1865, 189.2192, 189.1128,
         189.3968, 189.8484, 189.1134],
        [189.4370, 189

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.0230, 189.4369, 188.5216, 188.3930, 189.6746, 188.9584, 189.3082,
         188.9183, 189.1578, 189.0134],
        [188.8480, 189.2590, 188.3457, 188.2162, 189.4981, 188.7819, 189.1322,
         188.7418, 188.9822, 188.8360],
        [189.1311, 189.5442, 188.6302, 188.5001, 189.7840, 189.0666, 189.4168,
         189.0255, 189.2661, 189.1209],
        [189.1288, 189.5360, 188.6272, 188.4960, 189.7764, 189.0622, 189.4104,
         189.0213, 189.2611, 189.1147],
        [189.0343, 189.4492, 188.5329, 188.4044, 189.6892, 188.9706, 189.3214,
         188.9297, 189.1700, 189.0253],
        [188.9840, 189.3940, 188.4822, 188.3520, 189.6334, 188.9179, 189.2674,
         188.8775, 189.1176, 188.9716],
        [189.1643, 189.5791, 188.6635, 188.5344, 189.8189, 189.1007, 189.4510,
         189.0596, 189.2998, 189.1553],
        [188.9976, 189

tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.3058, 189.1217, 189.3383, 188.7109, 189.2538, 188.8419, 189.1861,
         188.4528, 189.2593, 189.3987],
        [189.4775, 189.2915, 189.5086, 188.8829, 189.4235, 189.0125, 189.3574,
         188.6242, 189.4290, 189.5699],
        [189.2046, 189.0195, 189.2363, 188.6085, 189.1519, 188.7395, 189.0844,
         188.3510, 189.1584, 189.2970],
        [189.1846, 189.0001, 189.2162, 188.5885, 189.1314, 188.7198, 189.0636,
         188.3311, 189.1370, 189.2774],
        [189.2545, 189.0682, 189.2843, 188.6577, 189.1998, 188.7880, 189.1329,
         188.4002, 189.2071, 189.3464],
        [189.2798, 189.0936, 189.3096, 188.6833, 189.2251, 188.8135, 189.1582,
         188.4257, 189.2325, 189.3718],
        [189.1736, 188.9877, 189.2033, 188.5764, 189.1181, 188.7073, 189.0510,
         188.3194, 189.1238, 189.2662],
        [189.2916, 189.1055, 189.3233, 188.6961, 189.23

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.6060, 188.7951, 189.1382, 189.2907, 189.2054, 189.0815, 188.4078,
         189.1549, 190.0584, 188.9855],
        [189.6002, 188.7873, 189.1319, 189.2836, 189.1989, 189.0730, 188.3997,
         189.1473, 190.0518, 188.9787],
        [189.4924, 188.6812, 189.0248, 189.1773, 189.0919, 188.9668, 188.2935,
         189.0421, 189.9440, 188.8722],
        [189.5269, 188.7133, 189.0582, 189.2103, 189.1257, 188.9982, 188.3250,
         189.0746, 189.9776, 188.9049],
        [189.5749, 188.7618, 189.1067, 189.2583, 189.1735, 189.0475, 188.3743,
         189.1220, 190.0264, 188.9534],
        [189.5107, 188.7003, 189.0424, 189.1961, 189.1107, 188.9858, 188.3121,
         189.0615, 189.9623, 188.8902],
        [189.6141, 188.8002, 189.1463, 189.2973, 189.2126, 189.0852, 188.4120,
         189.1613, 190.0644, 188.9933],
        [189.4584, 188

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.5344, 189.2347, 189.2653, 189.6659, 188.9707, 188.4345, 188.4282,
         188.3522, 188.6303, 189.0572],
        [188.4671, 189.1674, 189.1975, 189.6006, 188.9034, 188.3640, 188.3622,
         188.2835, 188.5627, 188.9867],
        [188.6891, 189.3867, 189.4178, 189.8192, 189.1236, 188.5865, 188.5812,
         188.5028, 188.7825, 189.2068],
        [188.4588, 189.1588, 189.1895, 189.5904, 188.8948, 188.3581, 188.3525,
         188.2764, 188.5549, 188.9811],
        [188.3865, 189.0864, 189.1166, 189.5193, 188.8222, 188.2834, 188.2808,
         188.2030, 188.4825, 188.9066],
        [188.4966, 189.1961, 189.2268, 189.6279, 188.9320, 188.3953, 188.3900,
         188.3128, 188.5919, 189.0173],
        [188.2809, 188.9839, 189.0135, 189.4162, 188.7187, 188.1797, 188.1779,
         188.1009, 188.3789, 188.8048],
        [188.5518, 189

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.0206, 189.0180, 189.2839, 190.0186, 189.1072, 189.7544, 189.0386,
         189.1783, 189.3486, 188.9831],
        [188.7756, 188.7737, 189.0419, 189.7751, 188.8627, 189.5103, 188.7948,
         188.9323, 189.1052, 188.7420],
        [188.8183, 188.8163, 189.0862, 189.8171, 188.9043, 189.5540, 188.8380,
         188.9757, 189.1485, 188.7850],
        [188.8965, 188.8938, 189.1608, 189.8938, 188.9823, 189.6315, 188.9149,
         189.0524, 189.2258, 188.8600],
        [189.1110, 189.1083, 189.3769, 190.1090, 189.1961, 189.8446, 189.1293,
         189.2686, 189.4383, 189.0753],
        [188.8324, 188.8309, 189.0983, 189.8327, 188.9203, 189.5666, 188.8515,
         188.9911, 189.1615, 188.7984],
        [188.9427, 188.9403, 189.2074, 189.9412, 189.0292, 189.6768, 188.9611,
         189.1002, 189.2711, 188.9067],
        [189.1395, 189

tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.3485, 189.2082, 188.8729, 188.9956, 189.8727, 188.8953, 189.0340,
         188.2937, 189.0275, 189.6109],
        [188.5045, 189.3657, 189.0312, 189.1540, 190.0298, 189.0523, 189.1907,
         188.4508, 189.1854, 189.7662],
        [188.4464, 189.3067, 188.9722, 189.0982, 189.9722, 188.9932, 189.1297,
         188.3940, 189.1295, 189.7065],
        [188.1541, 189.0143, 188.6786, 188.8044, 189.6802, 188.7006, 188.8394,
         188.1011, 188.8362, 189.4155],
        [188.3407, 189.1997, 188.8648, 188.9889, 189.8645, 188.8870, 189.0234,
         188.2865, 189.0205, 189.6024],
        [188.2948, 189.1559, 188.8214, 188.9470, 189.8211, 188.8423, 188.9791,
         188.2429, 188.9782, 189.5558],
        [188.5832, 189.4435, 189.1094, 189.2332, 190.1078, 189.1305, 189.2669,
         188.5297, 189.2645, 189.8440],
        [188.3686, 189.2296, 188.8954, 189.0177, 189.89

tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.9090, 188.6745, 189.0366, 189.3837, 189.5101, 189.2372, 189.8479,
         188.8740, 189.2932, 189.0684],
        [189.6704, 188.4349, 188.7971, 189.1456, 189.2714, 188.9992, 189.6092,
         188.6342, 189.0554, 188.8297],
        [189.6133, 188.3775, 188.7396, 189.0882, 189.2146, 188.9418, 189.5519,
         188.5764, 188.9981, 188.7725],
        [189.7941, 188.5613, 188.9222, 189.2717, 189.3994, 189.1232, 189.7341,
         188.7592, 189.1797, 188.9543],
        [189.4779, 188.2416, 188.6037, 188.9531, 189.0796, 188.8065, 189.4164,
         188.4402, 188.8630, 188.6371],
        [189.8171, 188.5833, 188.9452, 189.2934, 189.4193, 189.1471, 189.7566,
         188.7820, 189.2033, 188.9766],
        [189.8316, 188.5987, 188.9598, 189.3087, 189.4361, 189.1608, 189.7714,
         188.7968, 189.2171, 188.9916],
        [189.7554, 188.5215, 188.8834, 189.2322, 189.35

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.8682, 188.9091, 189.2044, 189.1109, 189.8123, 189.0769, 189.0044,
         188.9410, 188.8343, 188.9701],
        [188.8300, 188.8698, 189.1675, 189.0735, 189.7754, 189.0384, 188.9680,
         188.9037, 188.7953, 188.9339],
        [188.7211, 188.7639, 189.0594, 188.9661, 189.6685, 188.9299, 188.8610,
         188.7957, 188.6872, 188.8257],
        [188.7536, 188.7957, 189.0917, 188.9983, 189.7007, 188.9623, 188.8931,
         188.8281, 188.7195, 188.8582],
        [188.9354, 188.9826, 189.2738, 189.1822, 189.8837, 189.1452, 189.0772,
         189.0115, 188.9032, 189.0401],
        [188.8960, 188.9379, 189.2332, 189.1400, 189.8415, 189.1048, 189.0348,
         188.9700, 188.8618, 188.9991],
        [188.6816, 188.7258, 189.0188, 188.9261, 189.6284, 188.8906, 188.8195,
         188.7551, 188.6490, 188.7849],
        [188.6404, 188

       grad_fn=<CdistBackward0>)
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.9928, 189.1045, 188.6754, 188.9253, 188.4521, 189.0160, 188.7451,
         188.5816, 188.7732, 188.8441],
        [189.1254, 189.2357, 188.8080, 189.0575, 188.5851, 189.1478, 188.8765,
         188.7149, 188.9062, 188.9756],
        [189.0671, 189.1773, 188.7499, 189.0009, 188.5284, 189.0899, 188.8214,
         188.6573, 188.8496, 188.9180],
        [188.9112, 189.0222, 188.5940, 188.8435, 188.3709, 188.9337, 188.6637,
         188.5010, 188.6921, 188.7614],
        [188.9244, 189.0355, 188.6066, 188.8560, 188.3829, 188.9467, 188.6750,
         188.5135, 188.7045, 188.7741],
        [189.0715, 189.1823, 188.7532, 189.0053, 188.5312, 189.0946, 188.8240,
         188.6601, 188.8532, 188.9227],
        [189.2518, 189.3619, 188.9347, 189.1858, 188.7134, 189.2749, 189.0057,
         188.8417, 189.0343, 189.1033],
        [189.0182, 189

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.1300, 189.7630, 189.8327, 188.7113, 188.8366, 189.0744, 189.2806,
         188.7444, 188.9743, 189.3815],
        [189.1264, 189.7590, 189.8309, 188.7085, 188.8339, 189.0723, 189.2762,
         188.7428, 188.9717, 189.3805],
        [189.0520, 189.6848, 189.7538, 188.6329, 188.7578, 188.9946, 189.2014,
         188.6676, 188.8958, 189.3034],
        [189.1206, 189.7538, 189.8256, 188.7038, 188.8287, 189.0666, 189.2697,
         188.7387, 188.9670, 189.3760],
        [188.9708, 189.6030, 189.6739, 188.5518, 188.6772, 188.9147, 189.1196,
         188.5877, 188.8150, 189.2242],
        [188.9040, 189.5355, 189.6061, 188.4836, 188.6094, 188.8469, 189.0527,
         188.5195, 188.7468, 189.1560],
        [188.8568, 189.4894, 189.5591, 188.4376, 188.5634, 188.8008, 189.0057,
         188.4721, 188.7010, 189.1092],
        [189.0450, 189

       grad_fn=<CdistBackward0>)
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.9702, 188.9073, 188.7180, 189.0753, 189.7462, 189.6579, 188.6949,
         188.4479, 188.9573, 189.0452],
        [189.2057, 189.1396, 188.9527, 189.3093, 189.9788, 189.8920, 188.9291,
         188.6808, 189.1913, 189.2784],
        [189.2132, 189.1470, 188.9608, 189.3168, 189.9867, 189.8996, 188.9368,
         188.6883, 189.1989, 189.2863],
        [189.0861, 189.0216, 188.8316, 189.1906, 189.8595, 189.7732, 188.8097,
         188.5617, 189.0722, 189.1597],
        [188.9645, 188.8963, 188.7095, 189.0686, 189.7356, 189.6504, 188.6855,
         188.4379, 188.9483, 189.0363],
        [189.1232, 189.0596, 188.8685, 189.2279, 189.8971, 189.8107, 188.8473,
         188.5991, 189.1098, 189.1977],
        [188.9590, 188.8927, 188.7032, 189.0637, 189.7304, 189.6452, 188.6808,
         188.4335, 188.9436, 189.0307],
        [189.0510, 188

       grad_fn=<CdistBackward0>)
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.6444, 188.9456, 189.7046, 188.9694, 188.7955, 188.3979, 189.1035,
         188.8154, 189.0039, 188.7945],
        [189.4698, 188.7701, 189.5305, 188.7942, 188.6201, 188.2236, 188.9288,
         188.6404, 188.8296, 188.6190],
        [189.7209, 189.0203, 189.7820, 189.0444, 188.8726, 188.4769, 189.1803,
         188.8918, 189.0814, 188.8698],
        [189.6794, 188.9812, 189.7396, 189.0039, 188.8302, 188.4342, 189.1382,
         188.8509, 189.0391, 188.8300],
        [189.6842, 188.9835, 189.7454, 189.0077, 188.8361, 188.4394, 189.1440,
         188.8549, 189.0446, 188.8327],
        [189.5730, 188.8739, 189.6331, 188.8981, 188.7235, 188.3268, 189.0317,
         188.7438, 188.9326, 188.7230],
        [189.3358, 188.6375, 189.3968, 188.6604, 188.4858, 188.0880, 188.7952,
         188.5069, 188.6951, 188.4855],
        [189.5156, 188

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.1321, 189.0743, 189.1909, 188.8136, 189.7142, 188.8033, 189.5562,
         189.1062, 188.3078, 188.9630],
        [189.0537, 188.9945, 189.1115, 188.7342, 189.6350, 188.7240, 189.4765,
         189.0271, 188.2277, 188.8841],
        [189.0547, 189.0013, 189.1158, 188.7352, 189.6368, 188.7273, 189.4822,
         189.0300, 188.2325, 188.8844],
        [189.0625, 189.0058, 189.1212, 188.7394, 189.6433, 188.7341, 189.4881,
         189.0367, 188.2366, 188.8904],
        [189.1694, 189.1110, 189.2278, 188.8523, 189.7518, 188.8400, 189.5924,
         189.1431, 188.3452, 189.0011],
        [189.1141, 189.0603, 189.1750, 188.7946, 189.6963, 188.7868, 189.5416,
         189.0893, 188.2920, 188.9440],
        [188.9258, 188.8721, 188.9862, 188.6050, 189.5073, 188.5976, 189.3526,
         188.9006, 188.1019, 188.7545],
        [188.9556, 188

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.5718, 188.9884, 189.7207, 188.9275, 189.2706, 188.6356, 188.6205,
         189.1697, 188.6014, 188.6752],
        [188.7431, 189.1583, 189.8908, 189.0943, 189.4418, 188.8059, 188.7888,
         189.3363, 188.7718, 188.8456],
        [188.5011, 188.9176, 189.6498, 188.8565, 189.1999, 188.5649, 188.5496,
         189.0983, 188.5304, 188.6043],
        [188.6225, 189.0372, 189.7701, 188.9735, 189.3211, 188.6847, 188.6674,
         189.2155, 188.6507, 188.7246],
        [188.6010, 189.0163, 189.7499, 188.9559, 189.2991, 188.6626, 188.6478,
         189.1986, 188.6299, 188.7035],
        [188.4973, 188.9122, 189.6456, 188.8493, 189.1965, 188.5597, 188.5426,
         189.0910, 188.5255, 188.5998],
        [188.6665, 189.0838, 189.8168, 189.0235, 189.3662, 188.7312, 188.7163,
         189.2656, 188.6967, 188.7708],
        [188.3066, 188

tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.4004, 188.7865, 189.1467, 189.0884, 188.3564, 188.6173, 188.6325,
         189.3186, 188.3801, 188.7070],
        [188.5278, 188.9145, 189.2728, 189.2153, 188.4853, 188.7463, 188.7601,
         189.4445, 188.5077, 188.8369],
        [188.3409, 188.7282, 189.0872, 189.0283, 188.2972, 188.5583, 188.5728,
         189.2588, 188.3201, 188.6485],
        [188.3601, 188.7476, 189.1043, 189.0485, 188.3172, 188.5769, 188.5938,
         189.2760, 188.3401, 188.6683],
        [188.1739, 188.5613, 188.9194, 188.8613, 188.1296, 188.3903, 188.4058,
         189.0914, 188.1523, 188.4809],
        [188.3088, 188.6960, 189.0518, 188.9966, 188.2662, 188.5260, 188.5415,
         189.2242, 188.2879, 188.6178],
        [188.0225, 188.4102, 188.7654, 188.7087, 187.9791, 188.2394, 188.2532,
         188.9383, 187.9989, 188.3313],
        [188.3369, 188.7258, 189.0817, 189.0242, 188.29

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.0750, 188.7419, 188.8468, 188.9000, 188.7775, 188.4366, 189.1122,
         188.7885, 188.8719, 189.0481],
        [188.9692, 188.6385, 188.7410, 188.7972, 188.6710, 188.3292, 189.0063,
         188.6802, 188.7667, 188.9457],
        [188.8311, 188.4998, 188.6022, 188.6580, 188.5326, 188.1917, 188.8682,
         188.5432, 188.6278, 188.8064],
        [188.6938, 188.3611, 188.4649, 188.5193, 188.3954, 188.0545, 188.7300,
         188.4066, 188.4899, 188.6685],
        [188.7599, 188.4281, 188.5314, 188.5864, 188.4611, 188.1187, 188.7959,
         188.4698, 188.5562, 188.7360],
        [188.7406, 188.4075, 188.5121, 188.5648, 188.4411, 188.0981, 188.7758,
         188.4492, 188.5354, 188.7151],
        [189.0437, 188.7110, 188.8149, 188.8683, 188.7452, 188.4044, 189.0808,
         188.7561, 188.8393, 189.0164],
        [188.9332, 188.6002, 188.7047, 188.7572, 188.63

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.1991, 189.2629, 188.5066, 188.3393, 189.0182, 188.6466, 188.4168,
         188.4372, 188.6010, 188.4763],
        [188.4911, 189.5539, 188.7974, 188.6311, 189.3112, 188.9381, 188.7093,
         188.7269, 188.8926, 188.7683],
        [188.5361, 189.5995, 188.8432, 188.6749, 189.3554, 188.9829, 188.7531,
         188.7722, 188.9375, 188.8126],
        [188.3347, 189.3978, 188.6422, 188.4744, 189.1555, 188.7820, 188.5536,
         188.5714, 188.7367, 188.6116],
        [188.4098, 189.4724, 188.7155, 188.5480, 189.2305, 188.8563, 188.6272,
         188.6436, 188.8104, 188.6849],
        [188.1660, 189.2293, 188.4725, 188.3058, 188.9860, 188.6133, 188.3840,
         188.4024, 188.5673, 188.4423],
        [188.4015, 189.4650, 188.7083, 188.5406, 189.2207, 188.8485, 188.6184,
         188.6377, 188.8028, 188.6779],
        [188.3674, 189

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.9847, 189.1017, 188.5891, 188.8166, 188.5591, 188.7625, 189.5658,
         188.6765, 188.6587, 188.5357],
        [188.8903, 189.0074, 188.4933, 188.7192, 188.4633, 188.6659, 189.4706,
         188.5813, 188.5632, 188.4392],
        [188.5717, 188.6856, 188.1738, 188.4012, 188.1438, 188.3466, 189.1513,
         188.2612, 188.2418, 188.1181],
        [189.1033, 189.2217, 188.7090, 188.9345, 188.6778, 188.8810, 189.6845,
         188.7961, 188.7783, 188.6555],
        [188.9263, 189.0433, 188.5332, 188.7579, 188.5010, 188.7033, 189.5076,
         188.6190, 188.6005, 188.4779],
        [188.7674, 188.8829, 188.3727, 188.5978, 188.3419, 188.5428, 189.3485,
         188.4586, 188.4406, 188.3164],
        [189.0220, 189.1387, 188.6265, 188.8546, 188.5969, 188.8002, 189.6034,
         188.7138, 188.6964, 188.5732],
        [188.9082, 189

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.7498, 188.4661, 189.4399, 188.4152, 188.7112, 188.3482, 188.6153,
         188.6716, 188.4864, 188.5982],
        [188.8538, 188.5703, 189.5456, 188.5205, 188.8154, 188.4528, 188.7193,
         188.7745, 188.5910, 188.7046],
        [188.8073, 188.5242, 189.4982, 188.4744, 188.7688, 188.4064, 188.6722,
         188.7297, 188.5446, 188.6572],
        [189.0571, 188.7744, 189.7490, 188.7254, 189.0207, 188.6560, 188.9214,
         188.9781, 188.7955, 188.9088],
        [188.7304, 188.4471, 189.4213, 188.3976, 188.6940, 188.3282, 188.5952,
         188.6519, 188.4673, 188.5809],
        [188.8728, 188.5896, 189.5656, 188.5413, 188.8363, 188.4714, 188.7378,
         188.7930, 188.6103, 188.7257],
        [188.6821, 188.3987, 189.3743, 188.3512, 188.6458, 188.2802, 188.5466,
         188.6030, 188.4189, 188.5349],
        [188.8162, 188

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.6581, 188.7831, 188.7782, 188.5426, 188.4314, 188.6681, 188.1290,
         188.6228, 189.5147, 188.6015],
        [188.8456, 188.9708, 188.9635, 188.7297, 188.6184, 188.8548, 188.3163,
         188.8102, 189.7008, 188.7888],
        [188.8043, 188.9290, 188.9225, 188.6890, 188.5761, 188.8133, 188.2751,
         188.7687, 189.6595, 188.7479],
        [188.6096, 188.7342, 188.7283, 188.4949, 188.3804, 188.6184, 188.0801,
         188.5741, 189.4642, 188.5546],
        [188.6003, 188.7248, 188.7205, 188.4843, 188.3725, 188.6101, 188.0701,
         188.5641, 189.4558, 188.5431],
        [188.6712, 188.7960, 188.7914, 188.5563, 188.4438, 188.6809, 188.1424,
         188.6360, 189.5280, 188.6153],
        [188.6031, 188.7278, 188.7225, 188.4874, 188.3751, 188.6125, 188.0731,
         188.5673, 189.4581, 188.5467],
        [188.7996, 188

       grad_fn=<CdistBackward0>)
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.5260, 188.8328, 188.3746, 188.6043, 188.2385, 188.3852, 188.0936,
         188.3708, 188.2562, 188.1634],
        [188.7747, 189.0793, 188.6227, 188.8514, 188.4860, 188.6327, 188.3425,
         188.6192, 188.5049, 188.4121],
        [188.9925, 189.2994, 188.8417, 189.0707, 188.7058, 188.8507, 188.5631,
         188.8369, 188.7247, 188.6314],
        [188.7966, 189.1026, 188.6454, 188.8757, 188.5094, 188.6550, 188.3661,
         188.6433, 188.5280, 188.4352],
        [188.7050, 189.0116, 188.5544, 188.7842, 188.4180, 188.5641, 188.2743,
         188.5511, 188.4363, 188.3435],
        [188.6248, 188.9315, 188.4732, 188.7043, 188.3378, 188.4834, 188.1934,
         188.4713, 188.3558, 188.2626],
        [188.7578, 189.0656, 188.6060, 188.8359, 188.4711, 188.6159, 188.3267,
         188.6011, 188.4890, 188.3955],
        [188.6524, 188

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.2058, 188.6302, 188.6720, 188.4062, 188.7234, 188.8322, 188.7444,
         188.5064, 189.3517, 189.2376],
        [188.0498, 188.4730, 188.5161, 188.2498, 188.5671, 188.6764, 188.5882,
         188.3495, 189.1954, 189.0811],
        [187.9944, 188.4162, 188.4602, 188.1942, 188.5107, 188.6201, 188.5328,
         188.2933, 189.1398, 189.0242],
        [188.1850, 188.6084, 188.6514, 188.3853, 188.7012, 188.8098, 188.7238,
         188.4848, 189.3306, 189.2146],
        [188.1093, 188.5340, 188.5764, 188.3109, 188.6271, 188.7371, 188.6473,
         188.4106, 189.2556, 189.1418],
        [188.2423, 188.6680, 188.7099, 188.4433, 188.7595, 188.8681, 188.7809,
         188.5429, 189.3879, 189.2733],
        [188.1637, 188.5880, 188.6310, 188.3638, 188.6800, 188.7884, 188.7024,
         188.4628, 189.3088, 189.1931],
        [188.2430, 188

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.5786, 188.4813, 188.8726, 188.4004, 188.0612, 188.3820, 188.6572,
         188.6631, 188.6587, 189.0930],
        [188.4025, 188.3064, 188.6971, 188.2236, 187.8864, 188.2052, 188.4822,
         188.4874, 188.4821, 188.9159],
        [188.3179, 188.2210, 188.6117, 188.1383, 187.8015, 188.1203, 188.3968,
         188.4026, 188.3976, 188.8314],
        [188.4530, 188.3557, 188.7464, 188.2740, 187.9364, 188.2557, 188.5321,
         188.5378, 188.5325, 188.9666],
        [188.6749, 188.5784, 188.9686, 188.4971, 188.1576, 188.4788, 188.7520,
         188.7594, 188.7550, 189.1888],
        [188.5426, 188.4485, 188.8386, 188.3647, 188.0265, 188.3464, 188.6215,
         188.6276, 188.6227, 189.0558],
        [188.5361, 188.4433, 188.8334, 188.3584, 188.0201, 188.3402, 188.6154,
         188.6212, 188.6167, 189.0495],
        [188.5307, 188

tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.5854, 189.1297, 188.5484, 188.4572, 188.3039, 188.4352, 188.5249,
         188.2019, 188.3060, 188.3773],
        [188.7240, 189.2685, 188.6878, 188.5975, 188.4424, 188.5768, 188.6624,
         188.3402, 188.4445, 188.5177],
        [188.5636, 189.1073, 188.5276, 188.4360, 188.2816, 188.4142, 188.5024,
         188.1782, 188.2840, 188.3551],
        [188.7589, 189.3032, 188.7231, 188.6318, 188.4772, 188.6098, 188.6977,
         188.3754, 188.4797, 188.5513],
        [188.6861, 189.2309, 188.6491, 188.5589, 188.4049, 188.5379, 188.6250,
         188.3041, 188.4067, 188.4796],
        [188.5428, 189.0869, 188.5065, 188.4160, 188.2607, 188.3952, 188.4812,
         188.1573, 188.2628, 188.3356],
        [188.7094, 189.2534, 188.6733, 188.5817, 188.4277, 188.5596, 188.6485,
         188.3247, 188.4300, 188.5015],
        [188.6348, 189.1792, 188.5988, 188.5085, 188.35

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.5060, 188.2668, 187.9024, 188.3199, 188.2448, 189.2888, 188.3909,
         187.9545, 187.9781, 187.7226],
        [188.7786, 188.5408, 188.1758, 188.5924, 188.5182, 189.5623, 188.6650,
         188.2278, 188.2524, 187.9959],
        [188.5412, 188.3034, 187.9377, 188.3545, 188.2804, 189.3260, 188.4276,
         187.9913, 188.0142, 187.7590],
        [188.4835, 188.2461, 187.8796, 188.2972, 188.2226, 189.2681, 188.3699,
         187.9330, 187.9566, 187.7017],
        [188.5237, 188.2836, 187.9205, 188.3372, 188.2624, 189.3063, 188.4082,
         187.9727, 187.9953, 187.7397],
        [188.5286, 188.2899, 187.9243, 188.3421, 188.2673, 189.3117, 188.4141,
         187.9772, 188.0010, 187.7455],
        [188.4576, 188.2179, 187.8541, 188.2699, 188.1965, 189.2421, 188.3434,
         187.9084, 187.9294, 187.6744],
        [188.3886, 188

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.1717, 188.0894, 188.1803, 188.1374, 188.3219, 188.1433, 189.2349,
         188.1192, 187.5991, 188.1941],
        [188.2861, 188.2049, 188.2964, 188.2519, 188.4362, 188.2571, 189.3486,
         188.2344, 187.7133, 188.3085],
        [188.3436, 188.2616, 188.3509, 188.3091, 188.4957, 188.3145, 189.4078,
         188.2918, 187.7702, 188.3672],
        [188.3815, 188.2996, 188.3904, 188.3475, 188.5321, 188.3533, 189.4445,
         188.3297, 187.8090, 188.4041],
        [188.2010, 188.1203, 188.2117, 188.1659, 188.3503, 188.1711, 189.2634,
         188.1490, 187.6277, 188.2231],
        [188.5415, 188.4602, 188.5500, 188.5060, 188.6919, 188.5123, 189.6046,
         188.4899, 187.9680, 188.5643],
        [188.2153, 188.1343, 188.2258, 188.1810, 188.3654, 188.1859, 189.2778,
         188.1635, 187.6423, 188.2377],
        [188.2862, 188

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[189.3300, 188.2674, 188.7336, 188.6876, 188.5940, 188.0314, 188.7274,
         188.4460, 188.4242, 188.8096],
        [189.1812, 188.1218, 188.5851, 188.5420, 188.4442, 187.8807, 188.5790,
         188.2983, 188.2779, 188.6640],
        [188.9452, 187.8827, 188.3488, 188.3046, 188.2084, 187.6430, 188.3426,
         188.0616, 188.0403, 188.4265],
        [189.0919, 188.0299, 188.4956, 188.4527, 188.3549, 187.7912, 188.4897,
         188.2092, 188.1883, 188.5736],
        [189.0435, 187.9832, 188.4475, 188.4046, 188.3064, 187.7419, 188.4415,
         188.1608, 188.1403, 188.5264],
        [189.1813, 188.1211, 188.5850, 188.5415, 188.4449, 187.8815, 188.5792,
         188.2982, 188.2775, 188.6635],
        [189.1120, 188.0475, 188.5158, 188.4700, 188.3751, 187.8111, 188.5092,
         188.2281, 188.2061, 188.5913],
        [189.1239, 188

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.2614, 188.2893, 188.1596, 188.2081, 187.9806, 188.8681, 187.9453,
         188.2411, 188.1196, 187.5638],
        [188.2324, 188.2625, 188.1308, 188.1801, 187.9530, 188.8390, 187.9173,
         188.2121, 188.0906, 187.5362],
        [188.2859, 188.3153, 188.1836, 188.2328, 188.0062, 188.8916, 187.9703,
         188.2655, 188.1434, 187.5889],
        [188.1749, 188.2019, 188.0722, 188.1212, 187.8930, 188.7811, 187.8585,
         188.1547, 188.0320, 187.4763],
        [188.3430, 188.3739, 188.2410, 188.2905, 188.0647, 188.9485, 188.0280,
         188.3224, 188.2005, 187.6470],
        [188.2689, 188.2975, 188.1653, 188.2152, 187.9882, 188.8736, 187.9530,
         188.2482, 188.1242, 187.5708],
        [188.3479, 188.3740, 188.2447, 188.2941, 188.0650, 188.9543, 188.0311,
         188.3268, 188.2032, 187.6489],
        [188.2370, 188

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.0654, 188.0672, 188.1888, 188.2233, 188.5794, 188.1058, 188.0002,
         187.9437, 188.1546, 188.3681],
        [188.3095, 188.3119, 188.4297, 188.4691, 188.8208, 188.3501, 188.2429,
         188.1888, 188.3987, 188.6118],
        [187.9998, 188.0011, 188.1218, 188.1596, 188.5123, 188.0398, 187.9339,
         187.8780, 188.0900, 188.3029],
        [188.4271, 188.4299, 188.5492, 188.5858, 188.9393, 188.4689, 188.3613,
         188.3077, 188.5147, 188.7298],
        [187.9147, 187.9158, 188.0364, 188.0745, 188.4269, 187.9546, 187.8485,
         187.7926, 188.0048, 188.2180],
        [188.0500, 188.0516, 188.1710, 188.2100, 188.5616, 188.0902, 187.9835,
         187.9285, 188.1397, 188.3531],
        [187.8932, 187.8942, 188.0146, 188.0532, 188.4050, 187.9330, 187.8269,
         187.7709, 187.9834, 188.1967],
        [188.1918, 188.1941, 188.3138, 188.3502, 188.70

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.2633, 188.4849, 188.5019, 187.7714, 188.2822, 188.4113, 187.9445,
         188.5592, 187.9641, 188.5641],
        [188.1656, 188.3865, 188.4056, 187.6742, 188.1844, 188.3133, 187.8463,
         188.4616, 187.8677, 188.4673],
        [188.2007, 188.4221, 188.4406, 187.7092, 188.2198, 188.3484, 187.8816,
         188.4970, 187.9028, 188.5026],
        [188.1310, 188.3525, 188.3704, 187.6384, 188.1497, 188.2792, 187.8109,
         188.4274, 187.8326, 188.4325],
        [188.2026, 188.4258, 188.4422, 187.7101, 188.2226, 188.3506, 187.8840,
         188.4998, 187.9048, 188.5050],
        [188.0659, 188.2886, 188.3060, 187.5735, 188.0856, 188.2140, 187.7469,
         188.3632, 187.7686, 188.3683],
        [188.0782, 188.3016, 188.3184, 187.5850, 188.0981, 188.2264, 187.7587,
         188.3759, 187.7810, 188.3811],
        [188.3527, 188

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.3286, 188.0134, 188.1983, 188.0913, 188.2300, 188.3175, 188.1323,
         187.6110, 188.3350, 189.1709],
        [188.4026, 188.0876, 188.2731, 188.1666, 188.3081, 188.3925, 188.2070,
         187.6846, 188.4075, 189.2450],
        [188.5514, 188.2359, 188.4206, 188.3150, 188.4541, 188.5410, 188.3557,
         187.8332, 188.5559, 189.3932],
        [188.2579, 187.9420, 188.1275, 188.0199, 188.1577, 188.2466, 188.0608,
         187.5395, 188.2631, 189.0987],
        [188.1527, 187.8376, 188.0239, 187.9148, 188.0553, 188.1419, 187.9557,
         187.4341, 188.1577, 188.9937],
        [188.3157, 187.9996, 188.1844, 188.0787, 188.2163, 188.3042, 188.1192,
         187.5977, 188.3214, 189.1575],
        [188.2925, 187.9770, 188.1625, 188.0566, 188.1972, 188.2818, 188.0966,
         187.5744, 188.2976, 189.1346],
        [188.0999, 187

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.3098, 188.3606, 188.6736, 188.5072, 188.2744, 187.8497, 188.6371,
         188.2584, 188.8518, 188.6405],
        [188.1053, 188.1569, 188.4699, 188.3056, 188.0706, 187.6447, 188.4338,
         188.0560, 188.6480, 188.4370],
        [188.2217, 188.2744, 188.5882, 188.4234, 188.1874, 187.7635, 188.5516,
         188.1740, 188.7637, 188.5548],
        [187.9361, 187.9854, 188.2989, 188.1329, 187.9001, 187.4738, 188.2631,
         187.8814, 188.4784, 188.2665],
        [187.7667, 187.8188, 188.1320, 187.9691, 187.7321, 187.3045, 188.0969,
         187.7173, 188.3094, 188.0995],
        [188.0325, 188.0840, 188.3978, 188.2331, 187.9977, 187.5729, 188.3614,
         187.9825, 188.5747, 188.3648],
        [188.2540, 188.3065, 188.6197, 188.4552, 188.2197, 187.7953, 188.5830,
         188.2066, 188.7964, 188.5863],
        [188.0547, 188

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.4968, 187.9193, 188.2595, 188.4075, 188.6237, 188.0916, 188.1985,
         188.7218, 187.5893, 187.9561],
        [188.4647, 187.8886, 188.2285, 188.3779, 188.5910, 188.0606, 188.1683,
         188.6926, 187.5596, 187.9229],
        [188.3592, 187.7826, 188.1215, 188.2707, 188.4852, 187.9541, 188.0625,
         188.5850, 187.4521, 187.8179],
        [188.6446, 188.0689, 188.4077, 188.5564, 188.7698, 188.2404, 188.3479,
         188.8717, 187.7390, 188.1040],
        [188.1790, 187.6027, 187.9417, 188.0917, 188.3047, 187.7744, 187.8832,
         188.4059, 187.2722, 187.6368],
        [188.3677, 187.7920, 188.1313, 188.2811, 188.4925, 187.9641, 188.0719,
         188.5961, 187.4620, 187.8256],
        [188.6630, 188.0865, 188.4265, 188.5744, 188.7881, 188.2591, 188.3649,
         188.8897, 187.7564, 188.1223],
        [188.4728, 187

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.0126, 187.9986, 188.1145, 188.0942, 188.2555, 188.7191, 188.8327,
         187.9602, 187.9898, 187.9548],
        [188.1106, 188.0939, 188.2125, 188.1921, 188.3532, 188.8180, 188.9299,
         188.0568, 188.0869, 188.0513],
        [187.9886, 187.9723, 188.0900, 188.0697, 188.2312, 188.6956, 188.8080,
         187.9353, 187.9650, 187.9299],
        [187.9626, 187.9514, 188.0659, 188.0449, 188.2039, 188.6701, 188.7835,
         187.9105, 187.9413, 187.9042],
        [187.9808, 187.9650, 188.0825, 188.0624, 188.2236, 188.6879, 188.8002,
         187.9270, 187.9573, 187.9221],
        [187.9944, 187.9831, 188.0977, 188.0770, 188.2365, 188.7012, 188.8150,
         187.9416, 187.9725, 187.9361],
        [188.1414, 188.1274, 188.2437, 188.2228, 188.3838, 188.8478, 188.9614,
         188.0891, 188.1188, 188.0829],
        [188.3076, 188

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.4799, 188.1874, 187.9547, 188.4214, 188.1620, 188.1410, 187.9391,
         188.0387, 188.0050, 187.9156],
        [188.3532, 188.0646, 187.8283, 188.2976, 188.0349, 188.0146, 187.8135,
         187.9123, 187.8813, 187.7898],
        [188.4446, 188.1573, 187.9224, 188.3899, 188.1285, 188.1078, 187.9061,
         188.0048, 187.9730, 187.8832],
        [188.4971, 188.2095, 187.9734, 188.4419, 188.1796, 188.1589, 187.9580,
         188.0571, 188.0258, 187.9349],
        [188.3823, 188.0913, 187.8565, 188.3254, 188.0638, 188.0434, 187.8414,
         187.9407, 187.9088, 187.8182],
        [188.3561, 188.0682, 187.8343, 188.3002, 188.0397, 188.0187, 187.8184,
         187.9167, 187.8841, 187.7934],
        [188.5158, 188.2282, 187.9929, 188.4609, 188.1991, 188.1784, 187.9768,
         188.0760, 188.0443, 187.9542],
        [188.5750, 188

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.4979, 188.2028, 187.3472, 187.7875, 187.8414, 188.2174, 188.0806,
         187.9347, 188.8995, 188.0595],
        [187.6419, 188.3451, 187.4913, 187.9305, 187.9830, 188.3604, 188.2218,
         188.0784, 189.0410, 188.2009],
        [187.6128, 188.3153, 187.4633, 187.9014, 187.9553, 188.3307, 188.1928,
         188.0479, 189.0114, 188.1723],
        [187.6331, 188.3350, 187.4820, 187.9223, 187.9757, 188.3511, 188.2148,
         188.0683, 189.0332, 188.1940],
        [187.7446, 188.4487, 187.5964, 188.0337, 188.0876, 188.4635, 188.3247,
         188.1815, 189.1442, 188.3036],
        [187.3130, 188.0184, 187.1612, 187.6017, 187.6553, 188.0324, 187.8950,
         187.7493, 188.7136, 187.8740],
        [187.7060, 188.4099, 187.5571, 187.9962, 188.0506, 188.4252, 188.2887,
         188.1429, 189.1077, 188.2675],
        [187.5693, 188

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.8703, 187.8873, 188.1248, 188.0971, 188.1607, 187.7461, 187.6078,
         187.3067, 187.2781, 187.4571],
        [188.2400, 188.2567, 188.4927, 188.4660, 188.5294, 188.1173, 187.9782,
         187.6795, 187.6487, 187.8281],
        [188.1496, 188.1667, 188.4028, 188.3753, 188.4383, 188.0265, 187.8876,
         187.5875, 187.5582, 187.7379],
        [188.2224, 188.2397, 188.4754, 188.4480, 188.5109, 188.0998, 187.9606,
         187.6611, 187.6313, 187.8112],
        [188.0301, 188.0480, 188.2832, 188.2555, 188.3204, 187.9059, 187.7677,
         187.4667, 187.4388, 187.6173],
        [188.0214, 188.0383, 188.2761, 188.2482, 188.3102, 187.8978, 187.7590,
         187.4575, 187.4298, 187.6086],
        [188.0431, 188.0602, 188.2974, 188.2696, 188.3326, 187.9190, 187.7805,
         187.4785, 187.4520, 187.6293],
        [188.0001, 188

tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.9848, 187.8299, 188.6313, 188.4619, 187.9387, 187.9008, 188.3543,
         188.3441, 187.6257, 188.2612],
        [187.9089, 187.7558, 188.5586, 188.3900, 187.8655, 187.8261, 188.2796,
         188.2726, 187.5514, 188.1861],
        [187.9170, 187.7641, 188.5648, 188.3963, 187.8739, 187.8340, 188.2875,
         188.2793, 187.5592, 188.1943],
        [187.9205, 187.7670, 188.5703, 188.4012, 187.8764, 187.8378, 188.2912,
         188.2837, 187.5632, 188.1979],
        [187.8892, 187.7348, 188.5358, 188.3659, 187.8442, 187.8054, 188.2591,
         188.2491, 187.5309, 188.1664],
        [187.9011, 187.7463, 188.5505, 188.3809, 187.8551, 187.8174, 188.2712,
         188.2630, 187.5426, 188.1778],
        [188.0568, 187.9037, 188.7063, 188.5362, 188.0147, 187.9755, 188.4283,
         188.4186, 187.7022, 188.3355],
        [188.0152, 187.8630, 188.6648, 188.4954, 187.97

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.3818, 187.8711, 187.5310, 187.3510, 188.1795, 187.4777, 188.0234,
         188.0460, 188.3125, 188.6204],
        [187.2874, 187.7784, 187.4368, 187.2587, 188.0876, 187.3823, 187.9314,
         187.9533, 188.2201, 188.5280],
        [187.2657, 187.7583, 187.4161, 187.2387, 188.0680, 187.3624, 187.9098,
         187.9325, 188.2011, 188.5074],
        [187.4136, 187.9062, 187.5655, 187.3864, 188.2155, 187.5129, 188.0553,
         188.0788, 188.3486, 188.6523],
        [187.1819, 187.6736, 187.3320, 187.1535, 187.9828, 187.2777, 187.8262,
         187.8484, 188.1155, 188.4232],
        [187.4213, 187.9104, 187.5700, 187.3905, 188.2187, 187.5159, 188.0637,
         188.0856, 188.3512, 188.6598],
        [187.3098, 187.7999, 187.4584, 187.2801, 188.1089, 187.4042, 187.9534,
         187.9754, 188.2418, 188.5505],
        [187.5265, 188

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.8539, 188.0189, 188.6161, 188.1938, 187.9874, 188.2526, 187.5263,
         187.9937, 187.3758, 188.1703],
        [188.7573, 187.9236, 188.5206, 188.0950, 187.8926, 188.1565, 187.4287,
         187.8941, 187.2801, 188.0724],
        [188.7724, 187.9390, 188.5355, 188.1094, 187.9080, 188.1715, 187.4438,
         187.9089, 187.2955, 188.0872],
        [188.9772, 188.1419, 188.7384, 188.3151, 188.1103, 188.3756, 187.6508,
         188.1178, 187.5009, 188.2932],
        [188.6703, 187.8353, 188.4302, 188.0086, 187.8036, 188.0681, 187.3432,
         187.8085, 187.1916, 187.9856],
        [188.9020, 188.0685, 188.6648, 188.2381, 188.0373, 188.3008, 187.5736,
         188.0388, 187.4259, 188.2164],
        [188.7145, 187.8801, 188.4775, 188.0546, 187.8487, 188.1138, 187.3864,
         187.8532, 187.2361, 188.0309],
        [188.7844, 187.9502, 188.5487, 188.1235, 187.91

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.5138, 187.4680, 188.0607, 188.0629, 188.0656, 188.6042, 187.6669,
         187.8666, 187.7453, 188.3741],
        [188.6830, 187.6367, 188.2283, 188.2335, 188.2350, 188.7732, 187.8357,
         188.0372, 187.9156, 188.5449],
        [188.5461, 187.4997, 188.0922, 188.0950, 188.0973, 188.6360, 187.6985,
         187.8984, 187.7772, 188.4061],
        [188.8532, 187.8080, 188.4010, 188.4034, 188.4063, 188.9447, 188.0078,
         188.2090, 188.0855, 188.7150],
        [188.5357, 187.4926, 188.0849, 188.0887, 188.0911, 188.6288, 187.6922,
         187.8938, 187.7698, 188.4005],
        [188.6799, 187.6342, 188.2273, 188.2292, 188.2321, 188.7707, 187.8336,
         188.0338, 187.9116, 188.5405],
        [188.6618, 187.6163, 188.2090, 188.2109, 188.2139, 188.7525, 187.8152,
         188.0157, 187.8934, 188.5221],
        [188.5000, 187

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.0386, 187.4076, 187.2336, 187.9691, 188.1389, 187.7694, 187.6990,
         188.5841, 187.6329, 187.4910],
        [188.3147, 187.6857, 187.5094, 188.2480, 188.4172, 188.0457, 187.9733,
         188.8622, 187.9093, 187.7684],
        [188.3943, 187.7657, 187.5886, 188.3273, 188.4963, 188.1251, 188.0524,
         188.9415, 187.9889, 187.8477],
        [188.3354, 187.7061, 187.5297, 188.2684, 188.4371, 188.0658, 187.9942,
         188.8826, 187.9301, 187.7879],
        [188.3342, 187.7024, 187.5264, 188.2635, 188.4317, 188.0624, 187.9910,
         188.8784, 187.9257, 187.7846],
        [188.3526, 187.7232, 187.5470, 188.2822, 188.4518, 188.0836, 188.0119,
         188.8974, 187.9470, 187.8049],
        [188.1855, 187.5529, 187.3778, 188.1142, 188.2836, 187.9142, 187.8413,
         188.7295, 187.7759, 187.6379],
        [188.2334, 187.6028, 187.4273, 188.1623, 188.33

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.6563, 188.0325, 188.2660, 188.0972, 188.0183, 187.9893, 187.7800,
         187.9203, 187.5597, 187.6771],
        [187.4509, 187.8270, 188.0614, 187.8919, 187.8132, 187.7839, 187.5746,
         187.7145, 187.3546, 187.4719],
        [187.8003, 188.1758, 188.4087, 188.2406, 188.1606, 188.1324, 187.9246,
         188.0624, 187.7040, 187.8206],
        [187.8031, 188.1784, 188.4140, 188.2455, 188.1630, 188.1360, 187.9285,
         188.0671, 187.7065, 187.8242],
        [187.7792, 188.1549, 188.3878, 188.2196, 188.1398, 188.1115, 187.9035,
         188.0415, 187.6829, 187.7996],
        [187.9391, 188.3145, 188.5499, 188.3803, 188.2982, 188.2729, 188.0662,
         188.2041, 187.8422, 187.9600],
        [187.7057, 188.0817, 188.3171, 188.1478, 188.0673, 188.0394, 187.8302,
         187.9713, 187.6088, 187.7270],
        [187.6979, 188.0738, 188.3094, 188.1399, 188.0593, 188.0318, 187.8227,
         187.9636, 187.6010, 187

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.9974, 188.0573, 187.9938, 187.8477, 187.9967, 188.0728, 188.7532,
         188.6817, 187.8234, 187.7675],
        [188.0548, 188.1167, 188.0536, 187.9081, 188.0528, 188.1317, 188.8107,
         188.7412, 187.8804, 187.8252],
        [188.0015, 188.0643, 187.9981, 187.8551, 187.9999, 188.0779, 188.7570,
         188.6873, 187.8249, 187.7709],
        [187.9549, 188.0173, 187.9514, 187.8078, 187.9535, 188.0310, 188.7104,
         188.6402, 187.7784, 187.7244],
        [188.3094, 188.3698, 188.3068, 188.1616, 188.3083, 188.3857, 189.0656,
         188.9949, 188.1364, 188.0798],
        [187.7280, 187.7901, 187.7249, 187.5798, 187.7266, 187.8038, 188.4834,
         188.4130, 187.5516, 187.4977],
        [187.9762, 188.0380, 187.9735, 187.8293, 187.9744, 188.0527, 188.7311,
         188.6623, 187.8013, 187.7451],
        [188.0521, 188.1126, 188.0495, 187.9031, 188.05

tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.9424, 187.4045, 187.6585, 187.6326, 188.0447, 187.9544, 188.7007,
         188.0686, 187.8931, 188.2748],
        [187.7191, 187.1820, 187.4355, 187.4109, 187.8232, 187.7341, 188.4804,
         187.8454, 187.6716, 188.0542],
        [187.6130, 187.0765, 187.3292, 187.3066, 187.7178, 187.6265, 188.3743,
         187.7401, 187.5671, 187.9485],
        [187.5533, 187.0158, 187.2680, 187.2429, 187.6565, 187.5657, 188.3134,
         187.6787, 187.5032, 187.8866],
        [187.5356, 186.9985, 187.2508, 187.2266, 187.6391, 187.5491, 188.2958,
         187.6615, 187.4866, 187.8698],
        [187.6100, 187.0734, 187.3262, 187.3033, 187.7152, 187.6239, 188.3724,
         187.7370, 187.5642, 187.9458],
        [187.6300, 187.0931, 187.3455, 187.3223, 187.7333, 187.6415, 188.3884,
         187.7567, 187.5819, 187.9635],
        [187.6206, 187.0841, 187.3366, 187.3143, 187.72

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.1962, 187.9696, 188.5093, 188.2088, 187.9020, 188.2335, 188.2019,
         188.1193, 188.9546, 187.4118],
        [187.9529, 187.7254, 188.2654, 187.9642, 187.6588, 187.9900, 187.9582,
         187.8753, 188.7106, 187.1683],
        [188.0249, 187.7987, 188.3370, 188.0381, 187.7294, 188.0614, 188.0299,
         187.9472, 188.7834, 187.2388],
        [187.7783, 187.5506, 188.0905, 187.7894, 187.4825, 187.8138, 187.7824,
         187.7000, 188.5358, 186.9899],
        [188.1066, 187.8801, 188.4191, 188.1187, 187.8127, 188.1438, 188.1116,
         188.0286, 188.8653, 187.3226],
        [188.1646, 187.9373, 188.4781, 188.1759, 187.8706, 188.2017, 188.1701,
         188.0876, 188.9227, 187.3799],
        [187.9441, 187.7169, 188.2571, 187.9560, 187.6479, 187.9794, 187.9486,
         187.8666, 188.7018, 187.1549],
        [187.9926, 187

tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.9223, 187.7103, 187.4340, 187.7970, 188.1406, 187.2581, 187.4848,
         187.4137, 187.5601, 188.0220],
        [188.1433, 187.9282, 187.6528, 188.0171, 188.3579, 187.4792, 187.7031,
         187.6342, 187.7805, 188.2404],
        [188.2386, 188.0264, 187.7502, 188.1129, 188.4561, 187.5751, 187.8007,
         187.7302, 187.8762, 188.3376],
        [188.2408, 188.0262, 187.7503, 188.1143, 188.4552, 187.5769, 187.8006,
         187.7314, 187.8775, 188.3374],
        [188.1309, 187.9179, 187.6412, 188.0047, 188.3464, 187.4671, 187.6919,
         187.6215, 187.7672, 188.2281],
        [188.2073, 187.9924, 187.7164, 188.0807, 188.4210, 187.5433, 187.7668,
         187.6977, 187.8436, 188.3034],
        [188.0352, 187.8224, 187.5474, 187.9107, 188.2537, 187.3718, 187.5977,
         187.5282, 187.6747, 188.1359],
        [188.2997, 188.0854, 187.8097, 188.1734, 188.51

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.5966, 187.9447, 187.4936, 187.8710, 188.1563, 187.7406, 187.8219,
         187.6721, 187.7750, 187.7522],
        [187.5506, 187.8995, 187.4506, 187.8243, 188.1146, 187.6966, 187.7786,
         187.6261, 187.7264, 187.7041],
        [187.6146, 187.9636, 187.5148, 187.8873, 188.1786, 187.7608, 187.8423,
         187.6896, 187.7904, 187.7679],
        [187.6109, 187.9595, 187.5100, 187.8825, 188.1726, 187.7559, 187.8374,
         187.6853, 187.7879, 187.7654],
        [187.3318, 187.6816, 187.2296, 187.6064, 187.8937, 187.4769, 187.5580,
         187.4070, 187.5094, 187.4871],
        [187.5314, 187.8812, 187.4293, 187.8041, 188.0931, 187.6766, 187.7569,
         187.6055, 187.7091, 187.6861],
        [187.7203, 188.0677, 187.6186, 187.9935, 188.2809, 187.8646, 187.9465,
         187.7956, 187.8978, 187.8751],
        [187.5569, 187.9054, 187.4562, 187.8302, 188.1198, 187.7024, 187.7842,
         187.6322, 187.7332, 187

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.1404, 187.6508, 187.9913, 187.8427, 187.3274, 187.1658, 187.7820,
         187.4838, 187.6331, 188.5364],
        [187.1550, 187.6645, 188.0076, 187.8587, 187.3437, 187.1810, 187.7965,
         187.4991, 187.6496, 188.5501],
        [187.0118, 187.5224, 187.8609, 187.7124, 187.1972, 187.0348, 187.6516,
         187.3527, 187.5030, 188.4049],
        [187.2844, 187.7933, 188.1354, 187.9852, 187.4711, 187.3092, 187.9238,
         187.6255, 187.7769, 188.6780],
        [187.2937, 187.8022, 188.1458, 187.9961, 187.4823, 187.3192, 187.9337,
         187.6366, 187.7879, 188.6875],
        [187.2041, 187.7140, 188.0517, 187.9028, 187.3892, 187.2271, 187.8430,
         187.5449, 187.6940, 188.5966],
        [187.2255, 187.7359, 188.0744, 187.9257, 187.4117, 187.2503, 187.8663,
         187.5683, 187.7164, 188.6204],
        [187.3867, 187

tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[188.2375, 187.5772, 187.8509, 187.1480, 187.6126, 187.0009, 187.1740,
         187.4675, 187.4906, 187.0934],
        [188.5320, 187.8721, 188.1436, 187.4448, 187.9113, 187.2978, 187.4702,
         187.7633, 187.7875, 187.3923],
        [188.3212, 187.6643, 187.9354, 187.2348, 187.7000, 187.0857, 187.2598,
         187.5516, 187.5777, 187.1800],
        [188.1795, 187.5206, 187.7930, 187.0907, 187.5570, 186.9416, 187.1164,
         187.4109, 187.4344, 187.0368],
        [188.2505, 187.5892, 187.8627, 187.1602, 187.6264, 187.0128, 187.1866,
         187.4815, 187.5037, 187.1069],
        [188.2922, 187.6328, 187.9051, 187.2052, 187.6707, 187.0573, 187.2294,
         187.5238, 187.5463, 187.1518],
        [188.2391, 187.5827, 187.8530, 187.1544, 187.6203, 187.0043, 187.1776,
         187.4716, 187.4955, 187.1006],
        [188.2362, 187.5798, 187.8498, 187.1511, 187.61

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.5968, 187.3839, 188.2074, 187.5245, 187.2461, 187.5500, 187.4657,
         187.2333, 187.9523, 187.2969],
        [188.1618, 187.9499, 188.7723, 188.0895, 187.8118, 188.1125, 188.0315,
         187.7989, 188.5170, 187.8618],
        [187.9205, 187.7082, 188.5309, 187.8484, 187.5708, 187.8735, 187.7896,
         187.5575, 188.2772, 187.6211],
        [187.8654, 187.6541, 188.4757, 187.7936, 187.5156, 187.8186, 187.7352,
         187.5028, 188.2221, 187.5659],
        [187.9672, 187.7558, 188.5781, 187.8943, 187.6167, 187.9186, 187.8367,
         187.6058, 188.3222, 187.6673],
        [188.0430, 187.8327, 188.6539, 187.9705, 187.6930, 187.9946, 187.9131,
         187.6823, 188.3986, 187.7433],
        [187.9648, 187.7541, 188.5760, 187.8918, 187.6138, 187.9151, 187.8349,
         187.6040, 188.3186, 187.6644],
        [187.9720, 187.7598, 188.5822, 187.9008, 187.62

tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.4278, 186.8977, 187.5577, 187.5054, 187.1974, 187.5483, 187.4757,
         188.2837, 187.7585, 187.4886],
        [187.5361, 187.0081, 187.6649, 187.6147, 187.3052, 187.6562, 187.5822,
         188.3905, 187.8665, 187.5965],
        [187.4375, 186.9093, 187.5650, 187.5183, 187.2066, 187.5556, 187.4839,
         188.2920, 187.7688, 187.4960],
        [187.3776, 186.8481, 187.5078, 187.4565, 187.1476, 187.4982, 187.4262,
         188.2336, 187.7089, 187.4382],
        [187.5793, 187.0482, 187.7066, 187.6562, 187.3479, 187.6971, 187.6257,
         188.4345, 187.9090, 187.6385],
        [187.8051, 187.2766, 187.9311, 187.8823, 187.5730, 187.9228, 187.8493,
         188.6584, 188.1342, 187.8640],
        [187.7767, 187.2485, 187.9037, 187.8539, 187.5450, 187.8956, 187.8215,
         188.6301, 188.1060, 187.8363],
        [187.7236, 187.1944, 187.8523, 187.7998, 187.49

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.5763, 187.1146, 187.5092, 187.5642, 188.1632, 187.2681, 187.2761,
         187.1716, 186.7701, 187.9944],
        [187.8120, 187.3532, 187.7458, 187.8012, 188.3987, 187.5047, 187.5132,
         187.4079, 187.0056, 188.2292],
        [187.8429, 187.3834, 187.7774, 187.8324, 188.4301, 187.5359, 187.5466,
         187.4393, 187.0367, 188.2607],
        [187.7485, 187.2891, 187.6821, 187.7371, 188.3352, 187.4409, 187.4486,
         187.3440, 186.9423, 188.1660],
        [187.9604, 187.4987, 187.8933, 187.9494, 188.5474, 187.6530, 187.6638,
         187.5566, 187.1537, 188.3782],
        [187.6811, 187.2198, 187.6135, 187.6693, 188.2676, 187.3729, 187.3808,
         187.2764, 186.8744, 188.0985],
        [187.8520, 187.3901, 187.7844, 187.8385, 188.4383, 187.5436, 187.5503,
         187.4467, 187.0466, 188.2707],
        [187.8646, 187

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.9286, 188.4680, 187.8430, 188.0057, 187.4835, 187.1383, 188.3808,
         187.5995, 187.3218, 187.6592],
        [187.8172, 188.3564, 187.7305, 187.8947, 187.3727, 187.0263, 188.2713,
         187.4902, 187.2105, 187.5473],
        [187.7118, 188.2498, 187.6241, 187.7872, 187.2638, 186.9210, 188.1635,
         187.3819, 187.1027, 187.4412],
        [187.4110, 187.9488, 187.3224, 187.4869, 186.9627, 186.6193, 187.8622,
         187.0818, 186.8006, 187.1392],
        [187.7462, 188.2855, 187.6594, 187.8242, 187.3024, 186.9549, 188.2009,
         187.4198, 187.1395, 187.4761],
        [187.4975, 188.0353, 187.4088, 187.5733, 187.0496, 186.7059, 187.9496,
         187.1687, 186.8877, 187.2260],
        [187.6514, 188.1901, 187.5639, 187.7288, 187.2063, 186.8600, 188.1055,
         187.3244, 187.0438, 187.3809],
        [187.9824, 188

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[186.8771, 186.8656, 187.3408, 187.2997, 188.2570, 187.2839, 187.2673,
         188.4134, 187.6564, 187.6624],
        [186.6633, 186.6554, 187.1298, 187.0891, 188.0470, 187.0711, 187.0570,
         188.2035, 187.4467, 187.4520],
        [186.8204, 186.8127, 187.2861, 187.2450, 188.2045, 187.2311, 187.2145,
         188.3599, 187.6031, 187.6084],
        [186.6099, 186.6019, 187.0758, 187.0353, 187.9942, 187.0200, 187.0042,
         188.1493, 187.3932, 187.3983],
        [187.0635, 187.0526, 187.5256, 187.4865, 188.4448, 187.4734, 187.4557,
         188.6001, 187.8425, 187.8484],
        [186.8993, 186.8864, 187.3615, 187.3221, 188.2782, 187.3051, 187.2890,
         188.4346, 187.6773, 187.6835],
        [186.9702, 186.9572, 187.4323, 187.3929, 188.3487, 187.3754, 187.3596,
         188.5056, 187.7479, 187.7542],
        [186.8280, 186

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.4172, 187.5930, 187.2724, 186.8761, 187.5653, 187.3578, 186.9555,
         186.7041, 188.1013, 187.1515],
        [187.6299, 187.8052, 187.4852, 187.0902, 187.7782, 187.5695, 187.1684,
         186.9157, 188.3138, 187.3652],
        [187.6139, 187.7875, 187.4691, 187.0714, 187.7612, 187.5524, 187.1525,
         186.9005, 188.2971, 187.3479],
        [187.5204, 187.6945, 187.3756, 186.9786, 187.6678, 187.4596, 187.0587,
         186.8070, 188.2039, 187.2545],
        [187.7653, 187.9391, 187.6207, 187.2246, 187.9130, 187.7035, 187.3042,
         187.0509, 188.4485, 187.5005],
        [187.7211, 187.8947, 187.5765, 187.1806, 187.8688, 187.6590, 187.2595,
         187.0064, 188.4042, 187.4561],
        [187.4295, 187.6060, 187.2845, 186.8892, 187.5781, 187.3699, 186.9672,
         186.7154, 188.1135, 187.1640],
        [187.5366, 187

       grad_fn=<CdistBackward0>)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.3557, 188.0091, 186.8722, 187.3061, 187.6484, 187.1877, 187.4848,
         187.4114, 187.2283, 186.7673],
        [187.3439, 188.0006, 186.8582, 187.2941, 187.6378, 187.1723, 187.4722,
         187.4017, 187.2187, 186.7555],
        [187.4389, 188.0947, 186.9548, 187.3894, 187.7325, 187.2688, 187.5673,
         187.4962, 187.3128, 186.8503],
        [187.2909, 187.9453, 186.8080, 187.2420, 187.5843, 187.1230, 187.4214,
         187.3483, 187.1640, 186.7041],
        [187.4733, 188.1296, 186.9881, 187.4236, 187.7668, 187.3020, 187.6011,
         187.5306, 187.3477, 186.8844],
        [187.4211, 188.0770, 186.9370, 187.3715, 187.7149, 187.2510, 187.5495,
         187.4783, 187.2949, 186.8323],
        [187.4164, 188.0721, 186.9324, 187.3680, 187.7096, 187.2469, 187.5466,
         187.4746, 187.2905, 186.8301],
        [187.3361, 187

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.1535, 187.6732, 187.3044, 187.6006, 187.5684, 187.3254, 186.8055,
         188.1957, 187.3756, 188.0506],
        [186.8140, 187.3361, 186.9679, 187.2640, 187.2304, 186.9890, 186.4701,
         187.8607, 187.0395, 187.7121],
        [187.0203, 187.5411, 187.1728, 187.4688, 187.4359, 187.1938, 186.6744,
         188.0648, 187.2440, 187.9180],
        [187.0894, 187.6095, 187.2406, 187.5364, 187.5050, 187.2617, 186.7423,
         188.1323, 187.3119, 187.9874],
        [187.0612, 187.5819, 187.2138, 187.5087, 187.4786, 187.2354, 186.7166,
         188.1060, 187.2852, 187.9599],
        [187.1854, 187.7062, 187.3387, 187.6347, 187.6006, 187.3593, 186.8399,
         188.2305, 187.4096, 188.0818],
        [186.9008, 187.4224, 187.0542, 187.3499, 187.3175, 187.0755, 186.5566,
         187.9468, 187.1258, 187.7995],
        [187.2990, 187

tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[186.7220, 187.2888, 187.5054, 188.2650, 187.5957, 187.0981, 187.2236,
         187.4373, 188.0333, 187.3726],
        [186.5773, 187.1426, 187.3589, 188.1201, 187.4497, 186.9519, 187.0790,
         187.2938, 187.8875, 187.2277],
        [186.7042, 187.2701, 187.4844, 188.2483, 187.5759, 187.0798, 187.2065,
         187.4214, 188.0141, 187.3552],
        [186.6954, 187.2622, 187.4793, 188.2383, 187.5699, 187.0710, 187.1972,
         187.4109, 188.0054, 187.3468],
        [186.4724, 187.0378, 187.2547, 188.0152, 187.3450, 186.8472, 186.9739,
         187.1886, 187.7828, 187.1231],
        [186.7117, 187.2783, 187.4954, 188.2544, 187.5861, 187.0871, 187.2134,
         187.4271, 188.0215, 187.3629],
        [186.4048, 186.9711, 187.1877, 187.9485, 187.2782, 186.7806, 186.9063,
         187.1207, 187.7135, 187.0572],
        [186.5989, 187.1638, 187.3802, 188.1416, 187.47

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.5947, 187.1738, 187.5275, 187.8361, 187.1420, 186.7389, 187.4040,
         187.9604, 188.0644, 186.9426],
        [187.3006, 186.8785, 187.2346, 187.5433, 186.8475, 186.4453, 187.1120,
         187.6672, 187.7703, 186.6498],
        [187.7840, 187.3644, 187.7187, 188.0285, 187.3333, 186.9307, 187.5964,
         188.1521, 188.2558, 187.1345],
        [187.6155, 187.1948, 187.5484, 187.8571, 187.1631, 186.7599, 187.4249,
         187.9814, 188.0854, 186.9635],
        [187.6570, 187.2358, 187.5890, 187.8981, 187.2026, 186.8002, 187.4654,
         188.0211, 188.1260, 187.0047],
        [187.3547, 186.9321, 187.2880, 187.5968, 186.9002, 186.4987, 187.1631,
         187.7204, 187.8235, 186.7021],
        [187.4679, 187.0455, 187.3997, 187.7085, 187.0121, 186.6104, 187.2748,
         187.8314, 187.9359, 186.8146],
        [187.5293, 187

       grad_fn=<CdistBackward0>)
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.0029, 187.3035, 187.2593, 187.3640, 187.1064, 187.3189, 187.8035,
         188.3452, 187.8985, 187.3012],
        [186.9689, 187.2693, 187.2253, 187.3302, 187.0723, 187.2849, 187.7692,
         188.3113, 187.8648, 187.2670],
        [187.0071, 187.3092, 187.2641, 187.3671, 187.1110, 187.3223, 187.8099,
         188.3484, 187.9007, 187.3047],
        [186.8700, 187.1727, 187.1258, 187.2285, 186.9732, 187.1843, 187.6706,
         188.2099, 187.7630, 187.1666],
        [186.9335, 187.2354, 187.1888, 187.2951, 187.0375, 187.2491, 187.7347,
         188.2758, 187.8286, 187.2300],
        [186.8729, 187.1746, 187.1295, 187.2323, 186.9762, 187.1876, 187.6746,
         188.2137, 187.7665, 187.1699],
        [186.6495, 186.9523, 186.9041, 187.0083, 186.7524, 186.9631, 187.4493,
         187.9893, 187.5428, 186.9442],
        [186.8423, 187

       grad_fn=<CdistBackward0>)
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[186.6796, 187.3587, 187.5782, 187.6370, 186.8802, 187.5518, 188.0888,
         187.2185, 187.2181, 187.3692],
        [186.9123, 187.5895, 187.8124, 187.8697, 187.1149, 187.7859, 188.3202,
         187.4513, 187.4512, 187.6039],
        [186.7177, 187.3946, 187.6188, 187.6739, 186.9204, 187.5939, 188.1257,
         187.2567, 187.2562, 187.4103],
        [186.6988, 187.3755, 187.5983, 187.6567, 186.9003, 187.5725, 188.1076,
         187.2379, 187.2376, 187.3895],
        [186.5933, 187.2728, 187.4932, 187.5500, 186.7945, 187.4674, 188.0027,
         187.1325, 187.1320, 187.2842],
        [186.6755, 187.3521, 187.5747, 187.6314, 186.8769, 187.5500, 188.0839,
         187.2139, 187.2130, 187.3662],
        [186.8320, 187.5111, 187.7306, 187.7881, 187.0333, 187.7041, 188.2405,
         187.3704, 187.3699, 187.5220],
        [186.7181, 187

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.0731, 187.2246, 187.8978, 187.4733, 187.5202, 187.4462, 188.1989,
         187.1741, 187.2379, 187.1533],
        [186.9215, 187.0738, 187.7462, 187.3244, 187.3687, 187.2968, 188.0486,
         187.0235, 187.0866, 187.0022],
        [186.9944, 187.1475, 187.8214, 187.3967, 187.4432, 187.3695, 188.1223,
         187.0972, 187.1605, 187.0760],
        [187.2026, 187.3526, 188.0238, 187.6019, 187.6471, 187.5744, 188.3270,
         187.3023, 187.3655, 187.2816],
        [186.7535, 186.9069, 187.5784, 187.1580, 187.2017, 187.1290, 187.8814,
         186.8562, 186.9204, 186.8350],
        [187.3363, 187.4870, 188.1599, 187.7356, 187.7818, 187.7091, 188.4620,
         187.4370, 187.4994, 187.4160],
        [187.0026, 187.1542, 187.8267, 187.4033, 187.4496, 187.3759, 188.1283,
         187.1035, 187.1674, 187.0827],
        [187.1260, 187

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.0747, 187.2515, 187.4212, 187.2997, 187.0877, 186.9737, 187.4337,
         187.9528, 187.2739, 187.4182],
        [187.0098, 187.1851, 187.3573, 187.2366, 187.0253, 186.9105, 187.3688,
         187.8885, 187.2108, 187.3541],
        [187.3379, 187.5138, 187.6826, 187.5633, 187.3515, 187.2382, 187.6957,
         188.2153, 187.5363, 187.6796],
        [187.1980, 187.3739, 187.5439, 187.4238, 187.2116, 187.0983, 187.5557,
         188.0757, 187.3975, 187.5406],
        [187.3294, 187.5049, 187.6745, 187.5553, 187.3438, 187.2301, 187.6873,
         188.2070, 187.5284, 187.6712],
        [187.0928, 187.2671, 187.4379, 187.3202, 187.1089, 186.9950, 187.4508,
         187.9711, 187.2923, 187.4354],
        [186.8559, 187.0311, 187.2030, 187.0828, 186.8716, 186.7566, 187.2153,
         187.7349, 187.0563, 187.2006],
        [187.1638, 187

tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.5406, 187.1805, 187.7255, 187.1936, 187.2210, 187.4618, 187.5262,
         188.1947, 186.6550, 186.7346],
        [187.5068, 187.1501, 187.6955, 187.1616, 187.1878, 187.4290, 187.4930,
         188.1628, 186.6239, 186.7034],
        [187.1513, 186.7902, 187.3351, 186.8024, 186.8294, 187.0706, 187.1353,
         187.8040, 186.2644, 186.3410],
        [187.3397, 186.9775, 187.5219, 186.9911, 187.0181, 187.2588, 187.3241,
         187.9915, 186.4520, 186.5295],
        [187.2999, 186.9411, 187.4868, 186.9528, 186.9793, 187.2200, 187.2852,
         187.9539, 186.4153, 186.4924],
        [187.4613, 187.1029, 187.6480, 187.1150, 187.1420, 187.3837, 187.4468,
         188.1169, 186.5770, 186.6572],
        [187.1962, 186.8368, 187.3812, 186.8489, 186.8746, 187.1158, 187.1805,
         187.8491, 186.3105, 186.3870],
        [187.4200, 187.0586, 187.6041, 187.0716, 187.09

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.3069, 187.1826, 186.8221, 186.8015, 186.6482, 186.3035, 186.7483,
         186.9238, 186.8867, 186.7741],
        [187.6245, 187.4991, 187.1376, 187.1168, 186.9662, 186.6182, 187.0642,
         187.2381, 187.2006, 187.0899],
        [187.6224, 187.4978, 187.1377, 187.1171, 186.9646, 186.6187, 187.0645,
         187.2400, 187.2013, 187.0891],
        [187.4585, 187.3333, 186.9723, 186.9514, 186.7995, 186.4533, 186.8982,
         187.0725, 187.0357, 186.9245],
        [187.5910, 187.4655, 187.1037, 187.0829, 186.9326, 186.5842, 187.0303,
         187.2040, 187.1669, 187.0562],
        [187.3393, 187.2148, 186.8540, 186.8328, 186.6801, 186.3355, 186.7795,
         186.9540, 186.9172, 186.8067],
        [187.5685, 187.4455, 187.0859, 187.0648, 186.9113, 186.5676, 187.0122,
         187.1884, 187.1484, 187.0383],
        [187.5060, 187.3824, 187.0215, 187.0000, 186.84

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[186.9358, 187.2617, 187.0869, 186.9639, 187.9356, 187.2591, 187.0034,
         187.0608, 187.2572, 187.7045],
        [186.9586, 187.2840, 187.1090, 186.9862, 187.9575, 187.2810, 187.0261,
         187.0829, 187.2797, 187.7275],
        [186.9479, 187.2717, 187.0951, 186.9738, 187.9460, 187.2684, 187.0141,
         187.0695, 187.2683, 187.7184],
        [186.8975, 187.2225, 187.0457, 186.9252, 187.8981, 187.2208, 186.9648,
         187.0198, 187.2190, 187.6686],
        [187.0163, 187.3418, 187.1665, 187.0447, 188.0161, 187.3399, 187.0845,
         187.1401, 187.3379, 187.7864],
        [186.9659, 187.2903, 187.1132, 186.9929, 187.9664, 187.2883, 187.0328,
         187.0875, 187.2870, 187.7378],
        [186.8814, 187.2077, 187.0324, 186.9100, 187.8825, 187.2058, 186.9493,
         187.0064, 187.2032, 187.6507],
        [186.9219, 187

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.2048, 186.6451, 187.8092, 187.4910, 187.2815, 187.1346, 187.0517,
         187.2520, 187.3065, 186.9221],
        [187.0751, 186.5150, 187.6793, 187.3614, 187.1512, 187.0053, 186.9216,
         187.1230, 187.1769, 186.7923],
        [187.2965, 186.7386, 187.9017, 187.5853, 187.3742, 187.2294, 187.1433,
         187.3456, 187.3970, 187.0160],
        [186.9454, 186.3869, 187.5495, 187.2315, 187.0225, 186.8758, 186.7908,
         186.9926, 187.0473, 186.6626],
        [187.1620, 186.6036, 187.7673, 187.4515, 187.2394, 187.0957, 187.0088,
         187.2127, 187.2627, 186.8818],
        [187.1159, 186.5570, 187.7209, 187.4044, 187.1929, 187.0486, 186.9624,
         187.1658, 187.2169, 186.8349],
        [187.1652, 186.6076, 187.7699, 187.4526, 187.2421, 187.0974, 187.0107,
         187.2133, 187.2655, 186.8838],
        [187.2576, 186

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.2095, 186.9951, 186.5945, 186.9104, 186.4238, 187.4150, 187.1487,
         187.7341, 187.8821, 186.8705],
        [187.4794, 187.2647, 186.8665, 187.1810, 186.6981, 187.6867, 187.4178,
         188.0044, 188.1532, 187.1418],
        [187.1899, 186.9762, 186.5748, 186.8900, 186.4056, 187.3961, 187.1282,
         187.7144, 187.8635, 186.8507],
        [187.3913, 187.1784, 186.7777, 187.0933, 186.6081, 187.5986, 187.3295,
         187.9163, 188.0655, 187.0529],
        [187.2404, 187.0258, 186.6261, 186.9403, 186.4581, 187.4475, 187.1787,
         187.7651, 187.9144, 186.9023],
        [186.9988, 186.7858, 186.3833, 186.6984, 186.2133, 187.2060, 186.9373,
         187.5235, 187.6728, 186.6604],
        [187.3348, 187.1196, 186.7214, 187.0363, 186.5514, 187.5421, 187.2742,
         187.8599, 188.0078, 186.9979],
        [187.3502, 187

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.6045, 186.5910, 186.8330, 187.0591, 187.1510, 187.3916, 186.8451,
         186.8858, 186.6955, 187.1752],
        [187.7531, 186.7374, 186.9801, 187.2062, 187.2976, 187.5371, 186.9926,
         187.0326, 186.8418, 187.3248],
        [187.6629, 186.6470, 186.8893, 187.1142, 187.2054, 187.4451, 186.9001,
         186.9421, 186.7514, 187.2336],
        [187.7845, 186.7694, 187.0125, 187.2381, 187.3291, 187.5690, 187.0238,
         187.0646, 186.8741, 187.3561],
        [187.7113, 186.6960, 186.9385, 187.1630, 187.2533, 187.4937, 186.9483,
         186.9909, 186.8004, 187.2817],
        [187.3992, 186.3846, 186.6257, 186.8514, 186.9440, 187.1836, 186.6376,
         186.6790, 186.4886, 186.9695],
        [187.9152, 186.9002, 187.1438, 187.3696, 187.4606, 187.7013, 187.1560,
         187.1961, 187.0054, 187.4866],
        [187.5483, 186

       grad_fn=<CdistBackward0>)
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.1299, 186.7454, 186.0645, 186.9212, 187.5122, 187.3444, 186.3064,
         186.5957, 186.9159, 187.5945],
        [187.3287, 186.9431, 186.2634, 187.1180, 187.7098, 187.5428, 186.5062,
         186.7960, 187.1146, 187.7920],
        [187.4236, 187.0421, 186.3592, 187.2147, 187.8069, 187.6401, 186.6029,
         186.8930, 187.2114, 187.8885],
        [187.3497, 186.9612, 186.2834, 187.1379, 187.7285, 187.5605, 186.5257,
         186.8141, 187.1330, 187.8124],
        [187.5577, 187.1765, 186.4923, 187.3502, 187.9406, 187.7726, 186.7365,
         187.0233, 187.3431, 188.0229],
        [187.3378, 186.9520, 186.2717, 187.1281, 187.7180, 187.5496, 186.5143,
         186.8020, 187.1214, 187.8018],
        [187.3199, 186.9376, 186.2549, 187.1118, 187.7029, 187.5352, 186.4978,
         186.7867, 187.1064, 187.7851],
        [187.3166, 186

       grad_fn=<CdistBackward0>)
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.3519, 186.6919, 186.8360, 187.4598, 187.0174, 187.3137, 187.1856,
         187.3229, 186.9679, 186.9029],
        [187.0499, 186.3885, 186.5343, 187.1593, 186.7163, 187.0138, 186.8834,
         187.0232, 186.6651, 186.5991],
        [187.2050, 186.5442, 186.6898, 187.3153, 186.8727, 187.1687, 187.0385,
         187.1790, 186.8210, 186.7552],
        [187.2160, 186.5548, 186.7004, 187.3259, 186.8830, 187.1793, 187.0493,
         187.1892, 186.8317, 186.7663],
        [187.3130, 186.6544, 186.7984, 187.4216, 186.9805, 187.2752, 187.1478,
         187.2862, 186.9297, 186.8642],
        [187.3841, 186.7249, 186.8688, 187.4925, 187.0509, 187.3458, 187.2180,
         187.3562, 187.0007, 186.9356],
        [187.0644, 186.4028, 186.5484, 187.1721, 186.7295, 187.0271, 186.8968,
         187.0360, 186.6791, 186.6135],
        [186.9658, 186

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[186.3164, 187.0984, 187.7653, 186.2812, 186.7171, 186.9129, 186.8817,
         186.7384, 187.0635, 186.4758],
        [186.2859, 187.0706, 187.7364, 186.2541, 186.6891, 186.8832, 186.8541,
         186.7096, 187.0350, 186.4468],
        [186.3409, 187.1256, 187.7918, 186.3101, 186.7443, 186.9388, 186.9098,
         186.7655, 187.0909, 186.5018],
        [186.3488, 187.1332, 187.7995, 186.3174, 186.7541, 186.9449, 186.9178,
         186.7724, 187.0963, 186.5110],
        [186.3387, 187.1197, 187.7876, 186.3032, 186.7385, 186.9359, 186.9039,
         186.7610, 187.0863, 186.4975],
        [186.1578, 186.9416, 187.6093, 186.1253, 186.5612, 186.7561, 186.7272,
         186.5821, 186.9075, 186.3197],
        [186.5562, 187.3371, 188.0040, 186.5211, 186.9569, 187.1515, 187.1204,
         186.9776, 187.3017, 186.7148],
        [186.0504, 186

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[186.2695, 186.7928, 187.1171, 187.0873, 187.6011, 186.7356, 186.8057,
         186.7391, 186.4665, 186.2358],
        [186.3667, 186.8914, 187.2117, 187.1846, 187.6984, 186.8326, 186.9028,
         186.8399, 186.5631, 186.3354],
        [186.3835, 186.9075, 187.2294, 187.2014, 187.7150, 186.8499, 186.9198,
         186.8540, 186.5809, 186.3505],
        [186.2097, 186.7351, 187.0559, 187.0277, 187.5420, 186.6756, 186.7464,
         186.6838, 186.4061, 186.1793],
        [186.3446, 186.8692, 187.1900, 187.1624, 187.6763, 186.8105, 186.8808,
         186.8176, 186.5410, 186.3132],
        [186.2556, 186.7798, 187.1031, 187.0729, 187.5873, 186.7215, 186.7923,
         186.7262, 186.4520, 186.2225],
        [186.2156, 186.7409, 187.0620, 187.0336, 187.5479, 186.6818, 186.7523,
         186.6890, 186.4124, 186.1847],
        [186.5095, 187

       grad_fn=<CdistBackward0>)
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.0531, 187.0313, 187.1681, 187.0284, 186.3669, 186.9922, 186.9622,
         187.1667, 187.1647, 187.7429],
        [186.8124, 186.7892, 186.9270, 186.7853, 186.1247, 186.7481, 186.7185,
         186.9254, 186.9228, 187.5011],
        [186.8805, 186.8605, 186.9966, 186.8567, 186.1958, 186.8207, 186.7907,
         186.9957, 186.9934, 187.5722],
        [187.1765, 187.1538, 187.2911, 187.1511, 186.4901, 187.1154, 187.0850,
         187.2895, 187.2876, 187.8660],
        [186.6993, 186.6780, 186.8151, 186.6734, 186.0136, 186.6367, 186.6071,
         186.8143, 186.8112, 187.3902],
        [186.8664, 186.8441, 186.9817, 186.8403, 186.1803, 186.8040, 186.7743,
         186.9806, 186.9778, 187.5562],
        [186.8439, 186.8232, 186.9586, 186.8193, 186.1571, 186.7823, 186.7520,
         186.9573, 186.9556, 187.5348],
        [187.1304, 187

       grad_fn=<CdistBackward0>)
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.0826, 187.6898, 186.7568, 186.6940, 186.5594, 186.6908, 187.0304,
         186.9769, 187.6071, 186.5944],
        [187.0778, 187.6848, 186.7534, 186.6898, 186.5554, 186.6864, 187.0270,
         186.9729, 187.6038, 186.5894],
        [187.1239, 187.7318, 186.7992, 186.7360, 186.6019, 186.7321, 187.0731,
         187.0198, 187.6493, 186.6366],
        [187.1683, 187.7749, 186.8418, 186.7807, 186.6444, 186.7767, 187.1153,
         187.0627, 187.6910, 186.6809],
        [187.0940, 187.7009, 186.7695, 186.7060, 186.5715, 186.7026, 187.0430,
         186.9889, 187.6198, 186.6055],
        [187.3251, 187.9325, 187.0006, 186.9379, 186.8029, 186.9344, 187.2738,
         187.2208, 187.8497, 186.8389],
        [187.0187, 187.6261, 186.6949, 186.6312, 186.4971, 186.6267, 186.9691,
         186.9151, 187.5455, 186.5305],
        [187.1213, 187

       grad_fn=<CdistBackward0>)
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[187.3568, 186.1136, 186.1320, 187.0159, 186.1631, 186.7837, 186.5383,
         186.7272, 186.6449, 186.9391],
        [187.4466, 186.2003, 186.2200, 187.1028, 186.2503, 186.8740, 186.6257,
         186.8137, 186.7345, 187.0251],
        [187.4290, 186.1814, 186.2023, 187.0842, 186.2305, 186.8538, 186.6085,
         186.7954, 186.7144, 187.0082],
        [187.0976, 185.8529, 185.8711, 186.7549, 185.8988, 186.5238, 186.2775,
         186.4654, 186.3836, 186.6776],
        [187.3917, 186.1474, 186.1659, 187.0498, 186.1977, 186.8195, 186.5720,
         186.7606, 186.6804, 186.9721],
        [187.3294, 186.0855, 186.1042, 186.9879, 186.1347, 186.7561, 186.5105,
         186.6990, 186.6171, 186.9110],
        [187.4829, 186.2405, 186.2585, 187.1425, 186.2909, 186.9115, 186.6643,
         186.8539, 186.7726, 187.0648],
        [187.3494, 186

       grad_fn=<CdistBackward0>)
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[186.9676, 187.6714, 186.5818, 186.9342, 186.1968, 186.6022, 186.6801,
         187.0967, 186.8327, 186.4263],
        [186.9230, 187.6271, 186.5368, 186.8941, 186.1537, 186.5582, 186.6376,
         187.0531, 186.7902, 186.3825],
        [186.9360, 187.6409, 186.5517, 186.9039, 186.1674, 186.5706, 186.6516,
         187.0665, 186.8020, 186.3950],
        [186.9327, 187.6366, 186.5470, 186.8989, 186.1620, 186.5669, 186.6452,
         187.0619, 186.7976, 186.3914],
        [186.9502, 187.6525, 186.5632, 186.9189, 186.1789, 186.5852, 186.6624,
         187.0787, 186.8155, 186.4090],
        [186.9286, 187.6315, 186.5424, 186.8987, 186.1589, 186.5636, 186.6430,
         187.0580, 186.7947, 186.3877],
        [187.1705, 187.8745, 186.7859, 187.1393, 186.4011, 186.8069, 186.8854,
         187.3003, 187.0368, 186.6293],
        [186.6006, 187

       grad_fn=<CdistBackward0>)
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
Epoch: 595 Training Loss: 13.978919267654419 Training Accuracy: 0.1
	Learning_rate: 1e-06
tensor([[186.6825, 186.4389, 186.3871, 186.0106, 186.1929, 186.1833, 186.6735,
         186.3223, 186.5135, 186.4649],
        [186.9153, 186.6710, 186.6206, 186.2435, 186.4258, 186.4181, 186.9065,
         186.5554, 186.7469, 186.6994],
        [186.9353, 186.6899, 186.6408, 186.2635, 186.4464, 186.4381, 186.9259,
         186.5757, 186.7671, 186.7200],
        [187.0882, 186.8472, 186.7936, 186.4164, 186.5987, 186.5928, 187.0804,
         186.7288, 186.9203, 186.8714],
        [187.0698, 186.8262, 186.7754, 186.3972, 186.5794, 186.5730, 187.0614,
         186.7095, 186.9009, 186.8526],
        [187.0093, 186.7686, 186.7144, 186.3376, 186.5197, 186.5133, 187.0016,
         186.6497, 186.8413, 186.7924],
        [187.0514, 186.8077, 186.7571, 186.3785, 18

       grad_fn=<CdistBackward0>)
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
Epoch: 596 Training Loss: 13.987800598144531 Training Accuracy: 0.1
	Learning_rate: 1e-06
tensor([[186.8465, 186.8806, 186.6125, 186.5023, 186.0753, 186.6545, 186.6749,
         186.0574, 186.6073, 186.9992],
        [187.0585, 187.0933, 186.8244, 186.7125, 186.2881, 186.8701, 186.8909,
         186.2704, 186.8204, 187.2124],
        [186.8491, 186.8833, 186.6143, 186.5038, 186.0782, 186.6587, 186.6790,
         186.0598, 186.6099, 187.0014],
        [186.9869, 187.0208, 186.7526, 186.6418, 186.2164, 186.7952, 186.8160,
         186.1975, 186.7479, 187.1396],
        [186.8700, 186.9044, 186.6356, 186.5246, 186.0990, 186.6803, 186.7007,
         186.0813, 186.6310, 187.0229],
        [186.8076, 186.8392, 186.5717, 186.4595, 186.0331, 186.6145, 186.6356,
         186.0165, 186.5663, 186.9581],
        [186.8264, 186.8612, 186.5939, 186.4822, 18

Epoch: 597 Training Loss: 13.964612007141113 Training Accuracy: 0.1
	Learning_rate: 1e-06
tensor([[186.5361, 187.6663, 187.5260, 187.1436, 186.8086, 186.7109, 186.8610,
         186.9101, 186.6514, 186.6085],
        [186.3955, 187.5253, 187.3857, 187.0022, 186.6689, 186.5699, 186.7203,
         186.7690, 186.5101, 186.4669],
        [186.6326, 187.7612, 187.6223, 187.2362, 186.9059, 186.8060, 186.9565,
         187.0018, 186.7477, 186.7007],
        [186.2044, 187.3336, 187.1950, 186.8094, 186.4771, 186.3786, 186.5301,
         186.5768, 186.3182, 186.2741],
        [186.7108, 187.8416, 187.7008, 187.3166, 186.9831, 186.8862, 187.0364,
         187.0830, 186.8270, 186.7822],
        [186.4228, 187.5527, 187.4129, 187.0292, 186.6963, 186.5973, 186.7477,
         186.7961, 186.5375, 186.4942],
        [186.4799, 187.6092, 187.4698, 187.0862, 186.7539, 186.6538, 186.8040,
         186.8526, 186.5944, 186.5507],
        [186.5147, 187.6444, 187.5051, 187.1182, 186.7884, 186.6891, 186.8401

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
Epoch: 598 Training Loss: 14.009255409240723 Training Accuracy: 0.1
	Learning_rate: 1e-06
tensor([[186.7661, 186.9125, 186.8363, 186.7346, 187.7538, 186.8238, 187.1011,
         186.9754, 187.5219, 186.9457],
        [186.4268, 186.5730, 186.4969, 186.3960, 187.4161, 186.4836, 186.7623,
         186.6382, 187.1828, 186.6076],
        [186.5641, 186.7123, 186.6352, 186.5334, 187.5535, 186.6225, 186.9008,
         186.7764, 187.3216, 186.7445],
        [186.2278, 186.3769, 186.2995, 186.1971, 187.2179, 186.2866, 186.5657,
         186.4422, 186.9866, 186.4084],
        [186.5185, 186.6658, 186.5896, 186.4869, 187.5065, 186.5768, 186.8547,
         186.7296, 187.2768, 186.6975],
        [186.4738, 186.6196, 186.5434, 186.4422, 187.4619, 186.5309, 186.8088,
         186.6839, 187.2291, 186.6540],
        [186.4587, 186.6085, 186.5310, 186.4276, 187.4477, 186.5185, 186.7967,
         186.6722, 187.2186, 186.6380],
        [186.3951, 18

tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
Epoch: 599 Training Loss: 14.001387596130371 Training Accuracy: 0.1
	Learning_rate: 1e-06
tensor([[186.4520, 186.7368, 186.0979, 186.4122, 186.3496, 185.9028, 186.2788,
         186.6277, 186.6905, 186.4164],
        [186.5905, 186.8757, 186.2372, 186.5491, 186.4879, 186.0404, 186.4170,
         186.7651, 186.8277, 186.5547],
        [186.5084, 186.7926, 186.1541, 186.4685, 186.4057, 185.9590, 186.3355,
         186.6839, 186.7469, 186.4725],
        [186.6370, 186.9211, 186.2830, 186.5923, 186.5349, 186.0879, 186.4637,
         186.8125, 186.8766, 186.5999],
        [186.6755, 186.9610, 186.3226, 186.6327, 186.5742, 186.1249, 186.5025,
         186.8497, 186.9131, 186.6389],
        [186.3490, 186.6358, 185.9953, 186.3059, 186.2483, 185.8003, 186.1749,
         186.5254, 186.5887, 186.3132],
        [186.5918, 186.8760, 186.2378, 186.5470, 186.4899, 186.0426, 186.4184,
     

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[186.5310, 186.4665, 186.7005, 187.0757, 186.6544, 186.7381, 186.0434,
         186.3782, 185.8834, 186.3639],
        [186.5656, 186.5004, 186.7325, 187.1091, 186.6898, 186.7711, 186.0785,
         186.4105, 185.9157, 186.3981],
        [186.6609, 186.5964, 186.8281, 187.2055, 186.7855, 186.8669, 186.1739,
         186.5069, 186.0117, 186.4947],
        [186.5959, 186.5311, 186.7624, 187.1388, 186.7192, 186.8011, 186.1089,
         186.4409, 185.9469, 186.4276],
        [186.6046, 186.5403, 186.7742, 187.1493, 186.7284, 186.8118, 186.1175,
         186.4519, 185.9577, 186.4378],
        [186.3352, 186.2699, 186.5027, 186.8799, 186.4579, 186.5403, 185.8464,
         186.1806, 185.6829, 186.1665],
        [186.5179, 186.4536, 186.6852, 187.0637, 186.6408, 186.7233, 186.0294,
         186.3646, 185.8671, 186.3509],
        [186.6111, 186

tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[186.5726, 187.0652, 186.7044, 186.5218, 186.8765, 187.1695, 186.7541,
         187.5873, 186.9479, 186.7403],
        [186.1983, 186.6906, 186.3307, 186.1481, 186.5050, 186.7976, 186.3797,
         187.2141, 186.5749, 186.3703],
        [186.3771, 186.8696, 186.5090, 186.3266, 186.6824, 186.9751, 186.5586,
         187.3922, 186.7532, 186.5478],
        [186.5567, 187.0491, 186.6885, 186.5060, 186.8605, 187.1535, 186.7382,
         187.5711, 186.9319, 186.7242],
        [186.6613, 187.1516, 186.7927, 186.6087, 186.9654, 187.2590, 186.8418,
         187.6760, 187.0359, 186.8310],
        [186.4912, 186.9853, 186.6232, 186.4407, 186.7969, 187.0896, 186.6734,
         187.5074, 186.8680, 186.6614],
        [186.5841, 187.0739, 186.7160, 186.5322, 186.8881, 187.1818, 186.7644,
         187.5986, 186.9580, 186.7520],
        [186.3913, 186.8832, 186.5231, 186.3407, 186.69

       grad_fn=<CdistBackward0>)
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
Epoch: 603 Training Loss: 13.999058961868286 Training Accuracy: 0.1
	Learning_rate: 1e-06
tensor([[186.6628, 186.1359, 186.1207, 186.2925, 185.8751, 186.5386, 186.2755,
         186.3767, 186.3513, 186.0591],
        [186.9230, 186.3982, 186.3816, 186.5555, 186.1361, 186.8000, 186.5380,
         186.6386, 186.6129, 186.3206],
        [187.0213, 186.4931, 186.4786, 186.6483, 186.2307, 186.8963, 186.6308,
         186.7331, 186.7097, 186.4148],
        [186.8602, 186.3331, 186.3190, 186.4889, 186.0714, 186.7366, 186.4719,
         186.5734, 186.5496, 186.2552],
        [187.0141, 186.4881, 186.4729, 186.6442, 186.2256, 186.8908, 186.6268,
         186.7281, 186.7040, 186.4097],
        [186.9788, 186.4503, 186.4358, 186.6052, 186.1879, 186.8534, 186.5878,
         186.6902, 186.6669, 186.3722],
        [186.8407, 186.3123, 186.2966, 186.4682, 18

tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[186.6410, 186.5714, 186.5522, 186.4692, 187.0651, 187.2888, 186.5265,
         185.8232, 186.0435, 186.6943],
        [186.6894, 186.6196, 186.6010, 186.5173, 187.1149, 187.3380, 186.5760,
         185.8730, 186.0932, 186.7429],
        [186.6753, 186.6056, 186.5867, 186.5033, 187.1007, 187.3233, 186.5616,
         185.8581, 186.0785, 186.7284],
        [186.6281, 186.5582, 186.5392, 186.4548, 187.0554, 187.2767, 186.5156,
         185.8121, 186.0320, 186.6809],
        [186.6158, 186.5453, 186.5273, 186.4421, 187.0434, 187.2660, 186.5043,
         185.8011, 186.0206, 186.6681],
        [186.6990, 186.6291, 186.6106, 186.5269, 187.1241, 187.3475, 186.5854,
         185.8823, 186.1027, 186.7523],
        [186.8146, 186.7445, 186.7263, 186.6427, 187.2411, 187.4635, 186.7020,
         185.9979, 186.2188, 186.8662],
        [186.7250, 186.6553, 186.6363, 186.5533, 187.15

       grad_fn=<CdistBackward0>)
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[186.4389, 186.7794, 186.6111, 186.1121, 186.3990, 186.3487, 186.7963,
         186.9203, 186.2902, 186.6907],
        [186.5143, 186.8500, 186.6821, 186.1861, 186.4713, 186.4230, 186.8672,
         186.9917, 186.3621, 186.7605],
        [186.5066, 186.8441, 186.6770, 186.1792, 186.4653, 186.4157, 186.8622,
         186.9863, 186.3568, 186.7561],
        [186.2508, 186.5916, 186.4219, 185.9235, 186.2097, 186.1599, 186.6072,
         186.7322, 186.1008, 186.5004],
        [186.3852, 186.7238, 186.5540, 186.0574, 186.3430, 186.2943, 186.7395,
         186.8633, 186.2334, 186.6331],
        [186.3862, 186.7220, 186.5560, 186.0584, 186.3435, 186.2941, 186.7411,
         186.8666, 186.2361, 186.6340],
        [186.4366, 186.7737, 186.6081, 186.1094, 186.3952, 186.3450, 186.7929,
         186.9191, 186.2881, 186.6859],
        [186.1247, 186

       grad_fn=<CdistBackward0>)
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[186.4146, 186.4610, 186.1103, 186.8959, 186.6135, 186.4214, 186.5298,
         186.3547, 186.7722, 186.4941],
        [186.2349, 186.2764, 185.9275, 186.7124, 186.4303, 186.2383, 186.3472,
         186.1698, 186.5878, 186.3130],
        [186.4485, 186.4950, 186.1440, 186.9299, 186.6472, 186.4552, 186.5638,
         186.3884, 186.8060, 186.5277],
        [186.1004, 186.1420, 185.7938, 186.5789, 186.2961, 186.1042, 186.2124,
         186.0365, 186.4541, 186.1794],
        [186.3423, 186.3871, 186.0368, 186.8218, 186.5400, 186.3480, 186.4570,
         186.2801, 186.6979, 186.4215],
        [186.2474, 186.2918, 185.9422, 186.7296, 186.4449, 186.2533, 186.3602,
         186.1865, 186.6038, 186.3276],
        [186.2941, 186.3400, 185.9900, 186.7760, 186.4926, 186.3008, 186.4086,
         186.2346, 186.6519, 186.3741],
        [186.4552, 186

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[186.2900, 186.5356, 187.3517, 186.5468, 186.6637, 186.0887, 187.0038,
         186.3828, 186.8532, 187.2804],
        [186.1367, 186.3808, 187.1978, 186.3920, 186.5089, 185.9355, 186.8500,
         186.2281, 186.6990, 187.1274],
        [186.2268, 186.4711, 187.2861, 186.4800, 186.5982, 186.0236, 186.9411,
         186.3181, 186.7912, 187.2178],
        [186.1279, 186.3741, 187.1893, 186.3844, 186.5026, 185.9256, 186.8426,
         186.2212, 186.6918, 187.1189],
        [186.2686, 186.5132, 187.3284, 186.5216, 186.6384, 186.0673, 186.9831,
         186.3607, 186.8324, 187.2595],
        [186.0872, 186.3327, 187.1471, 186.3415, 186.4603, 185.8841, 186.8023,
         186.1798, 186.6516, 187.0786],
        [186.2522, 186.4950, 187.3125, 186.5070, 186.6241, 186.0497, 186.9645,
         186.3420, 186.8148, 187.2426],
        [186.1462, 186.3903, 187.2057, 186.3987, 186.51

tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')
tensor([[186.7200, 186.2705, 186.4414, 186.3807, 187.2976, 186.6452, 186.2807,
         186.3031, 186.2393, 186.6530],
        [186.6682, 186.2196, 186.3899, 186.3287, 187.2460, 186.5937, 186.2286,
         186.2510, 186.1875, 186.6011],
        [186.8841, 186.4370, 186.6054, 186.5455, 187.4614, 186.8105, 186.4443,
         186.4663, 186.4042, 186.8175],
        [186.7192, 186.2727, 186.4424, 186.3811, 187.2975, 186.6461, 186.2800,
         186.3015, 186.2405, 186.6530],
        [186.6728, 186.2256, 186.3960, 186.3353, 187.2505, 186.6000, 186.2334,
         186.2553, 186.1945, 186.6073],
        [186.6142, 186.1675, 186.3372, 186.2767, 187.1917, 186.5420, 186.1741,
         186.1966, 186.1357, 186.5489],
        [186.3840, 185.9366, 186.1078, 186.0456, 186.9622, 186.3113, 185.9440,
         185.9667, 185.9049, 186.3181],
        [186.7628, 186.3134, 186.4839, 186.4235, 187.34

KeyboardInterrupt: 