In [4]:
import torch as ch
import torchvision as tv

from ffcv.loader import Loader, OrderOption
from ffcv.transforms import (
    ToTensor,
    ToDevice,
    Squeeze,
    NormalizeImage,
    RandomHorizontalFlip,
    ToTorchImage,
    Convert
)
from ffcv.fields.rgb_image import (
    CenterCropRGBImageDecoder,
    RandomResizedCropRGBImageDecoder,
    SimpleRGBImageDecoder
)
from ffcv.fields.basics import IntDecoder, FloatDecoder
from tqdm import tqdm
import numpy as np

from ffcv.pipeline.operation import Operation, AllocationQuery
from ffcv.pipeline.compiler import Compiler
from abc import abstractmethod

from typing import Callable, Tuple, Optional
from ffcv.pipeline.state import State
from dataclasses import replace

In [214]:
mean = ch.load("/home/soroush1/projects/def-kohitij/soroush1/pretrain-imagenet/datasets/LaMem/support_files/image_mean_rgb.pt")
rnd_tensor = ch.rand(3, 256, 256)
print(f"{ch.equal(rnd_tensor, rnd_tensor)}")

print(f"{rnd_tensor[0, :10, 0]}")
print(f"{mean[0, :10, 0]}")
print(f"{(rnd_tensor - mean)[0, :10, 0]}")
print(f"{ch.equal(rnd_tensor, rnd_tensor - mean)}")
mean.size(), rnd_tensor.size()

True
tensor([0.2441, 0.1323, 0.7990, 0.3364, 0.2475, 0.6902, 0.4146, 0.7908, 0.5983,
        0.9590])
tensor([117.6884, 117.9610, 118.1297, 118.2820, 118.4500, 118.5726, 118.7102,
        118.8694, 118.9920, 119.0743])
tensor([-117.4443, -117.8287, -117.3307, -117.9456, -118.2025, -117.8824,
        -118.2956, -118.0786, -118.3937, -118.1153])
False


(torch.Size([3, 256, 256]), torch.Size([3, 256, 256]))

In [215]:
import torch

# Create example tensors
batch_size = 10
channels = 3
height = 32
width = 32

# Tensor of shape [batch, channels, h, w]
tensor_a = torch.randn(batch_size, channels, height, width)

# Tensor of shape [channels, h, w]
tensor_b = torch.randn(channels, height, width)

# Subtract tensor_b from tensor_a
result = tensor_a - tensor_b

print(result.shape)  # Should be [batch_size, channels, height, width]
print(result[0, 0, :10, 0])

torch.Size([10, 3, 32, 32])
tensor([-1.5850, -1.1977, -1.5490,  2.3488, -2.0080,  1.0427, -0.1712, -0.5773,
        -1.4751, -1.9090])


In [None]:
tensor_a.detach().cpu().numpy()

In [119]:
for i in range(tensor_a.shape[0]):

    result = tensor_a[i] - tensor_b
    print(result.shape)  # Should be [batch_size, channels, height, width]
    print(result[0, :10, 0])
    
    if i == 0:
        break

torch.Size([3, 32, 32])
tensor([-1.6084,  0.1366,  1.5696, -1.3446, -2.0777,  0.1969, -0.2545, -2.2963,
         0.6162, -0.7772])


In [184]:
np.uint8

numpy.uint8

In [220]:
class CustomNormalize(Operation):

    def __init__(self, mean):
        self.mean = mean.cpu().numpy() if isinstance(mean, torch.Tensor) else mean

    # Return the code to run this operation
    # @abstractmethod
    def generate_code(self) -> Callable:
        parallel_range = Compiler.get_iterator()
        mean = self.mean  # Capture mean as a local variable

        def subtract_mean(images, dst):
            for i in parallel_range(images.shape[0]):
                # print(f"Now here")
                # print(f"{mean.shape = }")
                # print(f"{images[i].size() = }")
                # print(f"{dst[i].size() = }")
                dst[i] = images[i] - mean

                if i == 0:
                    print(f"{(images[i] - mean)[0, :10, 0] = }")

            print(f"{dst[0, 0, :10, 0] = }")
            return dst
        subtract_mean.is_parallel = True
        return subtract_mean

    # @abstractmethod
    def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
        print(f"{previous_state.shape = }")
        c, h, w = previous_state.shape

        new_shape = (c, h, w)

        new_state = replace(previous_state, shape=new_shape)

        mem_allocation = AllocationQuery(new_shape, previous_state.dtype)
        
        return (new_state, mem_allocation)

class ToNumpy(Operation):
    # Return the code to run this operation
    # @abstractmethod
    def generate_code(self) -> Callable:

        def to_numpy(images, dst):
            # print(f"{images.size() = }")
            images = images.detach().cpu().numpy()
            images = np.transpose(images, (0, 2, 3, 1))
            # print(f"{images.dtype = }")
            # print(f"{type(images) = }")
            # return np.transpose(images.detach().cpu().numpy(), (0, 2, 3, 1)) # [20, 3, 256, 256] -> [20, 256, 256, 3]
            return images # [20, 3, 256, 256] -> [20, 256, 256, 3]
            
        return to_numpy

    # @abstractmethod
    def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
        c, h, w = previous_state.shape
        new_shape = (h, w, c)
    
        # Everything in the state stays the same other than the shape
        # States are immutable, so we have to edit them using the
        # dataclasses.replace function
        new_state = replace(previous_state, jit_mode=False, shape=new_shape, dtype=np.uint8)
        print(f"{new_state = }")

        # print(f"{previous_state.dtype = }")
        # We need to allocate memory for the new images
        # so below, we ask for a memory allocation whose width and height is
        # half the original image, with the same type
        # (shape=(,)) of the same type as the image data
        mem_allocation = AllocationQuery(new_shape, np.uint8)
        return (new_state, mem_allocation)


In [221]:
print("everything is loaded completely")

train_dataset = "/home/soroush1/projects/def-kohitij/soroush1/training_fast_publish_faster/data/lamem_train_256.ffcv"
num_workers = 1
batch_size = 20
distributed = 0
in_memory = True
this_device = "cuda:0"

res = 256
ratio = 256 / 256

LAMEM_MEAN = ch.load("/home/soroush1/projects/def-kohitij/soroush1/pretrain-imagenet/datasets/LaMem/support_files/image_mean_rgb.pt")
normalize = CustomNormalize(LAMEM_MEAN)
convert_to_numpy = ToNumpy()

center_crop_decoder = CenterCropRGBImageDecoder((256, 256), ratio)
random_resize_crop_decoder = RandomResizedCropRGBImageDecoder((res, res))

image_pipeline = [
    center_crop_decoder,
    ToTensor(),
    ToTorchImage(),
    # random_resize_crop_decoder,
    normalize,
    ToDevice(ch.device(this_device), non_blocking=True),
    Convert(ch.float16),
]

label_pipeline = [
    FloatDecoder(),
    ToTensor(),
    Squeeze(),
    ToDevice(ch.device(this_device), non_blocking=True),
]

order = OrderOption.RANDOM if distributed else OrderOption.QUASI_RANDOM

loader = Loader(
    train_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    order=order,
    os_cache=in_memory,
    drop_last=True,
    pipelines={"image": image_pipeline, "label": label_pipeline},
    distributed=distributed,
)

for i, (ims, labs) in tqdm(enumerate(loader)):
    print(f"{i}: {labs = }")
    print(f"{type(ims) = }")
    if isinstance(ims, np.ndarray):
        print(f"{i}: {ims.shape = }")
        print(f"{i}: {ims.dtype = }")

    else:
        print(f"{ims[0, 0, :10, 0] = }")
        print(f"{i}: {ims.size() = }")
        print(f"{i}: {ims.float().mean() = }")

    # with autocast():
    #     clf, reg = multi_head_model(ims)

    # print(f"{clf.size() = }")
    # print(f"{reg.size() = }")

    if i == 5:
        break

everything is loaded completely
previous_state.shape = (3, 256, 256)
previous_state.shape = (3, 256, 256)


0it [00:00, ?it/s]

(images[i] - mean)[0, :10, 0] = tensor([13.3116,  8.0390,  5.8703,  5.7180, 10.5500,  9.4274,  8.2898,  5.1306,
         4.0080,  4.9257])
dst[0, 0, :10, 0] = tensor([13,  8,  5,  5, 10,  9,  8,  5,  4,  4], dtype=torch.uint8)
0: labs = tensor([0.8043, 0.7907, 0.8205, 0.8684, 0.8667, 0.9211, 0.9111, 0.5556, 0.7234,
        0.5429, 0.7500, 0.8333, 0.9070, 0.6875, 0.9070, 0.8000, 0.6667, 0.6122,
        0.5385, 0.8140], device='cuda:0', dtype=torch.float64)
type(ims) = <class 'torch.Tensor'>
ims[0, 0, :10, 0] = tensor([13.,  8.,  5.,  5., 10.,  9.,  8.,  5.,  4.,  4.], device='cuda:0',
       dtype=torch.float16)
0: ims.size() = torch.Size([20, 3, 256, 256])


1it [00:01,  1.63s/it]

0: ims.float().mean() = tensor(123.5779, device='cuda:0')


2it [00:01,  1.18it/s]

(images[i] - mean)[0, :10, 0] = tensor([137.3116, 136.0390, 134.8703, 132.7180, 131.5500, 130.4274, 130.2898,
        129.1306, 129.0080, 128.9257])
dst[0, 0, :10, 0] = tensor([137, 136, 134, 132, 131, 130, 130, 129, 129, 128], dtype=torch.uint8)
1: labs = tensor([0.6905, 0.7000, 0.7619, 0.8667, 0.6944, 0.8750, 0.6429, 0.8049, 0.7949,
        0.5581, 0.7805, 0.8095, 0.9091, 0.7353, 0.6053, 0.7447, 0.8286, 0.6591,
        0.5000, 0.6000], device='cuda:0', dtype=torch.float64)
type(ims) = <class 'torch.Tensor'>
ims[0, 0, :10, 0] = tensor([137., 136., 134., 132., 131., 130., 130., 129., 129., 128.],
       device='cuda:0', dtype=torch.float16)
1: ims.size() = torch.Size([20, 3, 256, 256])
1: ims.float().mean() = tensor(141.2969, device='cuda:0')


3it [00:02,  1.45it/s]

(images[i] - mean)[0, :10, 0] = tensor([-111.6884, -111.9610, -112.1297, -112.2820, -112.4500, -112.5726,
        -112.7102, -112.8694, -112.9920, -113.0743])
dst[0, 0, :10, 0] = tensor([145, 145, 144, 144, 144, 144, 144, 144, 144, 143], dtype=torch.uint8)
2: labs = tensor([0.7500, 0.8095, 0.8718, 0.7273, 0.7568, 0.8621, 0.5000, 0.8723, 0.7442,
        0.4000, 0.8696, 0.5294, 0.7429, 0.8108, 0.6977, 0.6364, 0.8235, 0.9730,
        0.6512, 0.7561], device='cuda:0', dtype=torch.float64)
type(ims) = <class 'torch.Tensor'>
ims[0, 0, :10, 0] = tensor([145., 145., 144., 144., 144., 144., 144., 144., 144., 143.],
       device='cuda:0', dtype=torch.float16)
2: ims.size() = torch.Size([20, 3, 256, 256])
2: ims.float().mean() = tensor(129.6190, device='cuda:0')


4it [00:02,  1.62it/s]

(images[i] - mean)[0, :10, 0] = tensor([-41.6884, -43.9610, -46.1297, -49.2820, -50.4500, -50.5726, -47.7102,
        -46.8694, -44.9920, -45.0743])
dst[0, 0, :10, 0] = tensor([215, 213, 210, 207, 206, 206, 209, 210, 212, 211], dtype=torch.uint8)
3: labs = tensor([0.7273, 0.7838, 0.5938, 0.8108, 0.5122, 0.7818, 0.6857, 0.8750, 0.9677,
        0.6410, 0.7895, 0.8462, 0.5208, 0.8718, 0.8108, 0.8421, 0.9750, 0.7317,
        0.7500, 0.6585], device='cuda:0', dtype=torch.float64)
type(ims) = <class 'torch.Tensor'>
ims[0, 0, :10, 0] = tensor([215., 213., 210., 207., 206., 206., 209., 210., 212., 211.],
       device='cuda:0', dtype=torch.float16)
3: ims.size() = torch.Size([20, 3, 256, 256])
3: ims.float().mean() = tensor(131.2811, device='cuda:0')


5it [00:03,  1.73it/s]

(images[i] - mean)[0, :10, 0] = tensor([125.3116, 125.0390, 124.8703, 126.7180, 126.5500, 127.4274, 123.2898,
        122.1306, 122.0080, 122.9257])
dst[0, 0, :10, 0] = tensor([125, 125, 124, 126, 126, 127, 123, 122, 122, 122], dtype=torch.uint8)
4: labs = tensor([0.8000, 0.6579, 0.8611, 0.6410, 0.8667, 0.5366, 0.8718, 0.8810, 0.5952,
        0.8611, 0.8571, 0.9429, 0.8372, 0.8537, 0.7885, 0.6970, 0.8065, 0.5641,
        0.7805, 0.7297], device='cuda:0', dtype=torch.float64)
type(ims) = <class 'torch.Tensor'>
ims[0, 0, :10, 0] = tensor([125., 125., 124., 126., 126., 127., 123., 122., 122., 122.],
       device='cuda:0', dtype=torch.float16)
4: ims.size() = torch.Size([20, 3, 256, 256])
4: ims.float().mean() = tensor(137.3767, device='cuda:0')


5it [00:03,  1.31it/s]

(images[i] - mean)[0, :10, 0] = tensor([ -4.6884,  -8.9610, -13.1297, -63.2820, -72.4500, -82.5726, -12.7102,
        -28.8694, -47.9920,  -1.0743])
dst[0, 0, :10, 0] = tensor([252, 248, 243, 193, 184, 174, 244, 228, 209, 255], dtype=torch.uint8)
5: labs = tensor([0.9730, 0.6667, 0.7500, 0.6591, 0.7188, 0.8723, 0.9231, 0.7222, 0.5946,
        0.7353, 0.6591, 0.7368, 0.4737, 0.6897, 0.9189, 0.7179, 0.5128, 0.8421,
        0.7838, 0.7297], device='cuda:0', dtype=torch.float64)
type(ims) = <class 'torch.Tensor'>
ims[0, 0, :10, 0] = tensor([252., 248., 243., 193., 184., 174., 244., 228., 209., 255.],
       device='cuda:0', dtype=torch.float16)
5: ims.size() = torch.Size([20, 3, 256, 256])
5: ims.float().mean() = tensor(131.4686, device='cuda:0')





(images[i] - mean)[0, :10, 0] = tensor([2.3116, 2.0390, 1.8703, 2.7180, 2.5500, 3.4274, 4.2898, 4.1306, 4.0080,
        4.9257])
dst[0, 0, :10, 0] = tensor([2, 2, 1, 2, 2, 3, 4, 4, 4, 4], dtype=torch.uint8)
(images[i] - mean)[0, :10, 0] = tensor([13.3116, 13.0390, 12.8703, 12.7180, 13.5500, 13.4274, 12.2898, 12.1306,
        12.0080, 13.9257])
dst[0, 0, :10, 0] = tensor([13, 13, 12, 12, 13, 13, 12, 12, 12, 13], dtype=torch.uint8)
(images[i] - mean)[0, :10, 0] = tensor([-101.6884, -100.9610,  -98.1297,  -95.2820,  -90.4500,  -88.5726,
         -86.7102,  -85.8694,  -86.9920,  -88.0743])
dst[0, 0, :10, 0] = tensor([155, 156, 158, 161, 166, 168, 170, 171, 170, 168], dtype=torch.uint8)
(images[i] - mean)[0, :10, 0] = tensor([17.3116, 17.0390, 19.8703, 16.7180, 13.5500, 10.4274,  8.2898, -2.8694,
        -4.9920, -3.0743])
dst[0, 0, :10, 0] = tensor([ 17,  17,  19,  16,  13,  10,   8, 254, 252, 253], dtype=torch.uint8)


In [1]:
tensor_a = [0.5078, 0.5435, 0.5146, 0.5420, 0.5796, 0.5396, 0.5259, 0.5244, 0.5205,
        0.5269, 0.5225, 0.5605, 0.5132, 0.5312, 0.5479, 0.5444, 0.5630, 0.5718,
        0.5254, 0.5283, 0.5317, 0.5308, 0.5298, 0.5435, 0.5234, 0.5557, 0.5181,
        0.5381, 0.5352, 0.5322, 0.5293, 0.6719, 0.5234, 0.6313, 0.5107, 0.5220,
        0.5229, 0.5352, 0.5439, 0.6162, 0.5483, 0.5200, 0.5254, 0.5815, 0.5342,
        0.5337, 0.5049, 0.5293, 0.5430, 0.5366, 0.5352, 0.5459, 0.5405, 0.5190,
        0.5332, 0.5488, 0.5312, 0.5220, 0.5278, 0.5229, 0.5273, 0.5264, 0.5542,
        0.4927, 0.5532, 0.5386, 0.5283, 0.5166, 0.5371, 0.5259, 0.5322, 0.5366,
        0.5381, 0.5132, 0.5278, 0.5317, 0.5430, 0.5361, 0.5366, 0.4675, 0.5913,
        0.6084, 0.4988, 0.5391, 0.5356, 0.5200, 0.5361, 0.5273, 0.5225, 0.5361,
        0.5015, 0.5176, 0.5327, 0.5078, 0.5293, 0.5225, 0.5186, 0.5312, 0.5312,
        0.4829, 0.5254, 0.5483, 0.5400, 0.5298, 0.5327, 0.5400, 0.5435, 0.5220,
        0.5273, 0.5244, 0.5195, 0.5664, 0.5405, 0.5742, 0.5430, 0.5376, 0.5244,
        0.5649, 0.5718, 0.5420, 0.5278, 0.5317, 0.5234, 0.5293, 0.5210, 0.5308,
        0.5303, 0.5347, 0.5425, 0.5234, 0.5215, 0.5317, 0.5195, 0.5137, 0.5220,
        0.5112, 0.5210, 0.5166, 0.5015, 0.5503, 0.5298, 0.5112, 0.5283, 0.5337,
        0.5439, 0.5117, 0.5269, 0.5293, 0.5059, 0.5571, 0.5181, 0.5464, 0.4893,
        0.5391, 0.5557, 0.5205, 0.5366, 0.5059, 0.5137, 0.5361, 0.5264, 0.5352,
        0.5391, 0.5361, 0.5361, 0.4890, 0.5327, 0.6304, 0.5425, 0.5269, 0.5444,
        0.5371, 0.5327, 0.5283, 0.5371, 0.5166, 0.4783, 0.5513, 0.5176, 0.5391,
        0.5288, 0.5322, 0.5273, 0.5522, 0.5415, 0.5854, 0.5264, 0.5298, 0.5391,
        0.5112, 0.5391, 0.5317, 0.5386, 0.5449, 0.5522, 0.5483, 0.5371, 0.5283,
        0.5200, 0.5371, 0.5122, 0.6016, 0.5273, 0.5210, 0.5278, 0.5244, 0.5449,
        0.5474, 0.5283, 0.6738, 0.5234, 0.5864, 0.6055, 0.5181, 0.5464, 0.5942,
        0.5688, 0.5278, 0.6743, 0.6499, 0.5288, 0.5205, 0.5098, 0.5356, 0.5254,
        0.5347, 0.4988, 0.5303, 0.5151, 0.4995, 0.5186, 0.5483, 0.5503, 0.5103,
        0.5439, 0.5210, 0.5352, 0.5430, 0.5244, 0.5361, 0.5239, 0.5503, 0.5283,
        0.5337, 0.5464, 0.5273, 0.5278, 0.5420, 0.5386, 0.5190, 0.5400, 0.5215,
        0.4338, 0.5518, 0.5376, 0.5029, 0.5249, 0.5229, 0.5195, 0.5288, 0.5044,
        0.5166, 0.5464, 0.5234, 0.5269, 0.5142, 0.5073, 0.4878, 0.5054, 0.5254,
        0.5273, 0.5264, 0.5127, 0.5371, 0.5532, 0.5308, 0.5518, 0.5200, 0.5205,
        0.4954, 0.5308, 0.5386, 0.5156, 0.5693, 0.5386, 0.5347, 0.5366, 0.5151,
        0.5444, 0.8726, 0.5308, 0.4556, 0.5273, 0.5254, 0.5078, 0.5327, 0.5239,
        0.5503, 0.5317, 0.5278, 0.5254, 0.5615, 0.5317, 0.4978, 0.5200, 0.5342,
        0.5312, 0.5283, 0.5767, 0.5205, 0.5137, 0.5254, 0.5420, 0.5288, 0.6333,
        0.5430, 0.5508, 0.5508, 0.5186, 0.5273, 0.5278, 0.5205, 0.5439, 0.5405,
        0.5361, 0.5186, 0.5762, 0.5415, 0.5332, 0.5278, 0.4819, 0.5522, 0.5098,
        0.5366, 0.5269, 0.5063, 0.5405, 0.5273, 0.5239, 0.5220, 0.5254, 0.5464,
        0.5420, 0.5396, 0.5015, 0.5562, 0.5386, 0.5522, 0.5835, 0.5410, 0.5005,
        0.5444, 0.5107, 0.5142, 0.5249, 0.4648, 0.5171, 0.5327, 0.5132, 0.5347,
        0.5210, 0.5410, 0.5327, 0.4978, 0.5400, 0.5649, 0.5742, 0.5415, 0.5142,
        0.5342, 0.5312, 0.5508, 0.5625, 0.5210, 0.5342, 0.5444, 0.5190, 0.5093,
        0.5068, 0.5425, 0.5234, 0.4924, 0.5151, 0.5327, 0.4983, 0.5361, 0.5391,
        0.5122, 0.5117, 0.5488, 0.5376, 0.5366, 0.5425, 0.5244, 0.5186, 0.5239,
        0.5405, 0.5415, 0.5024, 0.5317, 0.5288, 0.5342, 0.5171, 0.5327, 0.5146,
        0.4861, 0.5210, 0.5884, 0.5757, 0.5444, 0.5356, 0.5098, 0.5283, 0.5278,
        0.5312, 0.5474, 0.5332, 0.5420, 0.6528, 0.5405, 0.5171, 0.5386, 0.5303,
        0.5200, 0.5410, 0.4727, 0.4993, 0.5381, 0.5288, 0.5200, 0.5298, 0.5806,
        0.5356, 0.5479, 0.4321, 0.5479, 0.5381, 0.5952, 0.5684, 0.5312, 0.5229,
        0.5225, 0.5410, 0.5308, 0.5322, 0.5454, 0.5225, 0.5444, 0.5420, 0.5215,
        0.5854, 0.5273, 0.5371, 0.4932, 0.5205, 0.5234, 0.5312, 0.5283, 0.5322,
        0.5132, 0.5376, 0.5308, 0.5244, 0.5220, 0.5332, 0.5439, 0.5044, 0.5356,
        0.5347, 0.5347, 0.5312, 0.5244, 0.5176, 0.5386, 0.5103, 0.5029, 0.5376,
        0.4495, 0.5503, 0.5522, 0.5254, 0.5400, 0.5459, 0.4963, 0.5166, 0.5703,
        0.5264, 0.5972, 0.5767, 0.5718, 0.5264, 0.5088, 0.4937, 0.5776, 0.5327,
        0.5391, 0.5493, 0.5264, 0.5347, 0.5059, 0.5386, 0.5249, 0.5317, 0.4978,
        0.5425, 0.5278, 0.5283, 0.4883, 0.5210, 0.5361, 0.5137, 0.5298]

In [2]:
tensor_b = [0.5000, 0.8718, 0.9130, 0.6750, 0.7241, 0.9394, 0.5714, 0.7826, 0.8378,
        0.4242, 0.9189, 0.9231, 0.6216, 0.9333, 0.9250, 0.7353, 0.6875, 0.8214,
        0.8235, 0.5366, 0.6176, 0.3590, 0.7391, 0.7727, 0.8409, 0.8571, 0.8049,
        0.7556, 0.8333, 0.9556, 0.6744, 0.7857, 0.9556, 0.6000, 0.6774, 0.8500,
        0.7179, 0.7442, 0.9459, 0.7879, 0.9286, 0.5476, 0.7660, 0.8649, 0.7500,
        0.5652, 0.6047, 0.9250, 0.8837, 0.8409, 0.8919, 0.9211, 0.8000, 0.7755,
        0.7105, 0.5833, 0.7353, 0.7826, 0.7000, 0.8182, 0.7568, 0.6286, 0.7297,
        0.8537, 0.7500, 0.7250, 0.8205, 0.8293, 0.7576, 0.4595, 0.6667, 0.7400,
        0.6923, 0.9000, 0.8378, 0.9268, 0.7442, 0.7838, 0.7436, 0.8438, 0.7778,
        0.8235, 0.5122, 0.6250, 0.7188, 0.8200, 0.7000, 0.7568, 0.6667, 0.8810,
        0.7949, 0.8857, 0.6667, 0.8000, 0.5161, 0.7250, 0.6316, 0.6250, 0.4390,
        0.7561, 0.7778, 0.9412, 0.8974, 0.5556, 0.9000, 0.5333, 0.8750, 0.8140,
        0.4571, 0.6190, 0.7027, 0.8085, 0.7500, 0.6000, 0.9048, 0.7750, 0.8478,
        0.8158, 0.8718, 0.8333, 0.5938, 0.7273, 0.8571, 0.7222, 0.4762, 0.9623,
        0.7400, 0.5814, 0.8857, 0.7000, 0.7436, 0.6154, 0.7805, 0.7556, 0.7179,
        0.8571, 0.8056, 0.8571, 0.6667, 0.6923, 0.8947, 0.6591, 0.7500, 0.8222,
        0.8500, 0.7000, 0.7429, 0.7609, 0.6571, 0.7436, 0.7021, 0.5476, 0.9643,
        0.7805, 0.7368, 0.8636, 0.7059, 0.7561, 0.9706, 0.6410, 0.5946, 0.7442,
        0.8667, 0.7568, 0.7209, 0.6346, 0.8293, 0.8974, 0.7561, 0.8000, 0.7500,
        0.8788, 0.5526, 0.9268, 0.6744, 0.6579, 0.8372, 0.9167, 0.8800, 0.6977,
        0.7647, 0.8378, 0.8780, 0.6000, 0.5862, 0.7632, 0.8710, 0.7297, 0.5476,
        0.7429, 0.8095, 0.4884, 0.5455, 0.7059, 0.7750, 0.5116, 0.7857, 0.8684,
        0.8750, 0.7895, 0.5000, 0.8913, 0.8182, 0.6905, 0.8276, 0.8095, 0.7805,
        0.5135, 0.6857, 0.7179, 0.6585, 0.9487, 0.9048, 0.8780, 0.8378, 0.7750,
        0.9535, 0.6452, 0.8085, 0.9189, 0.5946, 0.7234, 0.7805, 0.8065, 0.7368,
        0.4857, 0.8378, 0.6829, 0.9714, 0.8085, 0.7143, 0.6667, 0.9706, 0.8182,
        0.5714, 0.9412, 0.8000, 0.8250, 0.7949, 0.7105, 0.8649, 0.5526, 0.8222,
        0.9091, 0.8621, 0.7895, 0.6889, 0.8462, 0.9250, 0.6190, 0.7805, 0.3548,
        0.6279, 0.7742, 0.6500, 0.6444, 0.7907, 0.8710, 0.8000, 0.7727, 0.7805,
        0.8542, 0.5294, 0.8696, 0.6279, 0.7576, 0.9211, 0.8108, 0.8718, 0.9118,
        0.6512, 0.7222, 0.8718, 0.6857, 0.5909, 0.7727, 0.9189, 0.9444, 0.7500,
        0.9000, 0.6944, 0.7179, 0.7568, 0.9375, 0.9091, 0.8919, 0.8421, 0.5882,
        0.8864, 0.9737, 0.6667, 0.8205, 0.6875, 0.9118, 0.8696, 0.8824, 0.7442,
        0.8056, 0.7778, 0.9070, 0.9268, 0.7179, 0.8085, 0.4737, 0.8333, 0.7778,
        0.5429, 0.7568, 0.9722, 0.8947, 0.6944, 0.8293, 0.8857, 0.6061, 0.7083,
        0.8409, 0.7188, 0.8140, 0.8684, 0.7660, 0.8605, 0.7500, 0.8095, 0.6667,
        0.7419, 0.8857, 0.7895, 0.8571, 0.9167, 0.7174, 0.6765, 0.8205, 0.7045,
        0.7143, 0.7500, 0.8049, 0.8947, 0.7551, 0.7429, 0.9130, 0.8529, 0.4524,
        0.7347, 0.9744, 0.8571, 0.7692, 0.9750, 0.7317, 0.8537, 0.6842, 0.4872,
        0.7619, 0.8409, 0.7143, 0.6809, 0.7021, 0.8000, 0.6383, 0.4444, 0.8387,
        0.6842, 0.9545, 0.7000, 0.8000, 0.4722, 0.7800, 0.7907, 0.8043, 0.7576,
        0.8182, 0.5526, 0.8974, 0.6875, 0.8500, 0.7368, 0.6809, 0.8158, 0.6190,
        0.6765, 0.8000, 0.9091, 0.7234, 0.7857, 0.8542, 0.5263, 0.7714, 0.7234,
        0.7667, 0.7143, 0.5814, 0.5556, 0.4737, 0.9070, 0.6136, 0.9167, 0.6875,
        0.7391, 0.6410, 1.0000, 0.8158, 0.6000, 0.6364, 0.9149, 0.8286, 0.8261,
        0.7561, 0.8387, 0.6829, 0.6739, 0.7333, 0.7447, 0.5366, 0.8529, 0.6905,
        0.7727, 0.8250, 0.6667, 0.8205, 0.7500, 0.8043, 0.5116, 0.8421, 0.8857,
        0.7907, 0.8684, 0.6000, 0.8537, 0.8611, 0.7632, 0.8974, 0.8780, 0.6591,
        0.5556, 0.8286, 0.8158, 0.7234, 0.8500, 0.8333, 0.7179, 0.8158, 0.7381,
        0.6905, 0.7000, 0.6944, 0.5128, 0.8857, 0.7143, 0.8947, 0.6486, 0.7333,
        0.8605, 0.8750, 0.8140, 0.7436, 0.6667, 0.7872, 0.8537, 0.6042, 0.8222,
        0.8000, 0.7714, 0.8378, 0.6452, 0.9750, 0.6098, 0.6667, 0.8750, 0.7692,
        0.8889, 0.9333, 0.8718, 0.8182, 0.5897, 0.6136, 0.7778, 0.5250, 0.6250,
        0.7917, 0.7838, 0.8750, 0.7059, 0.4524, 0.7436, 0.6364, 0.7872, 0.8095,
        0.6000, 0.5750, 0.6512, 0.8649, 0.8378, 0.8140, 0.8205, 0.5682, 0.8108,
        0.8163, 0.8049, 0.7333, 0.8529, 0.6757, 0.9355, 0.7353, 1.0000, 0.6364,
        0.4688, 0.9474, 0.5854, 0.8864, 0.9333, 0.6591, 0.8000, 0.7447]

In [5]:
tensor_a = ch.Tensor(tensor_a)
tensor_b = ch.Tensor(tensor_b)


In [6]:
loss = ch.nn.MSELoss()

In [10]:
loss(tensor_a, tensor_b).float().dtype

torch.float32