In [1]:
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch.distributions.normal import Normal

### Old Utils

In [2]:
def axis_angle_to_matrix(ax, ay, az, device):
    sx = torch.sin(torch.tensor(ax, device=device))
    cx = torch.cos(torch.tensor(ax, device=device))
    sy = torch.sin(torch.tensor(ay, device=device))
    cy = torch.cos(torch.tensor(ay, device=device))
    sz = torch.sin(torch.tensor(az, device=device))
    cz = torch.cos(torch.tensor(az, device=device))
    Rx = torch.tensor([
        [1,   0,    0],
        [0,   cx,  -sx],
        [0,   sx,   cx]
    ], device=device, dtype=torch.float32)
    Ry = torch.tensor([
        [ cy,   0,  sy],
        [  0,   1,   0],
        [-sy,   0,  cy]
    ], device=device, dtype=torch.float32)
    Rz = torch.tensor([
        [ cz, -sz,  0],
        [ sz,  cz,  0],
        [  0,   0,  1]
    ], device=device, dtype=torch.float32)
    R = torch.matmul(Rz, torch.matmul(Ry, Rx))
    return R


def rotate_tensor(tensor, angle_rad, axes, device):
    B = tensor.size(0)
    dtype = tensor.dtype  
    R = torch.eye(3, device=device, dtype=dtype).unsqueeze(0).repeat(B, 1, 1)  

    if axes == (2, 3):
        cos_a = torch.cos(angle_rad).type(dtype)
        sin_a = torch.sin(angle_rad).type(dtype)
        R[:,1,1] = cos_a
        R[:,1,2] = -sin_a
        R[:,2,1] = sin_a
        R[:,2,2] = cos_a
    elif axes == (2, 4):
        cos_a = torch.cos(angle_rad).type(dtype)
        sin_a = torch.sin(angle_rad).type(dtype)
        R[:,0,0] = cos_a
        R[:,0,2] = sin_a
        R[:,2,0] = -sin_a
        R[:,2,2] = cos_a
    elif axes == (3, 4):
        cos_a = torch.cos(angle_rad).type(dtype)
        sin_a = torch.sin(angle_rad).type(dtype)
        R[:,0,0] = cos_a
        R[:,0,1] = -sin_a
        R[:,1,0] = sin_a
        R[:,1,1] = cos_a
    else:
        raise ValueError("Invalid axes for rotation. Must be one of (2, 3), (2, 4), or (3, 4).")

    affine_matrix = torch.zeros(B, 3, 4, device=device, dtype=dtype)
    affine_matrix[:, :3, :3] = R
    affine_matrix[:, :3, 3] = 0 

    grid = F.affine_grid(affine_matrix, tensor.size(), align_corners=False)  

    if torch.isnan(grid).any() or torch.isinf(grid).any():
        print("NaNs or Infs detected in grid in rotate_tensor.")
        raise ValueError("Invalid grid in rotate_tensor.")

    rotated = F.grid_sample(tensor, grid, mode='bilinear', padding_mode='border', align_corners=False)
    return rotated

def shift_tensor(tensor, shift, device):
    B, C, D, H, W = tensor.size()
    shifts_normalized = shift.clone()
    shifts_normalized[:, 0] = shift[:, 0] * 2 / (W - 1)  
    shifts_normalized[:, 1] = shift[:, 1] * 2 / (H - 1)  
    shifts_normalized[:, 2] = shift[:, 2] * 2 / (D - 1) 
    affine_matrix = torch.eye(4, device=device).unsqueeze(0).repeat(B, 1, 1)
    affine_matrix[:, :3, 3] = -shifts_normalized
    affine_matrices = affine_matrix[:, :3, :]
    grid = F.affine_grid(affine_matrices, tensor.size(), align_corners=True)

    shifted = F.grid_sample(
        tensor, grid,
        mode='nearest',
        padding_mode='border',
        align_corners=True
    )
    return shifted

def shear_tensor(x, shear_factors, device):
    B, C, D, H, W = x.shape
    sheared = []
    for i in range(B):
        shear_x, shear_y, shear_z = shear_factors[i]
        affine_matrix = torch.tensor([
            [1, shear_x, shear_y, 0],
            [0, 1, shear_z, 0],
            [0, 0, 1, 0]
        ], dtype=torch.float32, device=device)
        grid = F.affine_grid(affine_matrix.unsqueeze(0), x[i:i+1].size(), align_corners=True)
        sheared_sample = F.grid_sample(x[i:i+1], grid, mode='bilinear', padding_mode='zeros', align_corners=True)
        sheared.append(sheared_sample)
    sheared = torch.cat(sheared, dim=0)
    return sheared
    
def compose_rotation(ax, ay, az, device):
    Rx = axis_angle_to_matrix(ax, 0.0, 0.0, device)
    Ry = axis_angle_to_matrix(0.0, ay, 0.0, device)
    Rz = axis_angle_to_matrix(0.0, 0.0, az, device)
    R_approx = Rz @ (Ry @ Rx)
    U, _, V = torch.linalg.svd(R_approx)
    det_val = torch.linalg.det(torch.matmul(U, V.transpose(-2, -1)))
    if det_val < 0:
        V[..., -1] *= -1.0
    R = torch.matmul(U, V.transpose(-2, -1))
    return R

def r6_to_matrix(r6):
    if r6.size(-1) != 6:
        raise ValueError(f"r6_to_matrix expects last dimension=6, got {r6.size(-1)}")
    with torch.cuda.amp.autocast(enabled=False):
        eps = 1e-7
        v1 = r6[..., :3]
        v1_norm = torch.norm(v1, dim=-1, keepdim=True)
        v1 = v1 / (v1_norm + eps)
        v2 = r6[..., 3:]
        v2 = v2 - torch.sum(v2 * v1, dim=-1, keepdim=True) * v1
        v2_norm = torch.norm(v2, dim=-1, keepdim=True)
        v2 = v2 / (v2_norm + eps)
        v3 = torch.cross(v1, v2, dim=-1)
        R = torch.stack([v1, v2, v3], dim=-2)
        det = torch.linalg.det(R)
        v3 = torch.where(det.unsqueeze(-1) < 0, -v3, v3)
        R = torch.stack([v1, v2, v3], dim=-2)

    return R

def matrix_to_r6(R):
    v1 = R[..., 0]  
    v2 = R[..., 1]  
    return torch.cat([v1, v2], dim=-1)


def apply_pretrain_transformations(x, device, max_rotation_angle, translation_range, shearing_range, args):
    B, C, D, H, W = x.shape

    angles_deg = torch.FloatTensor(B, 3).uniform_(-max_rotation_angle, max_rotation_angle).to(device)
    angles_rad = angles_deg * math.pi / 180.0

    x_rot = rotate_tensor(x, angles_rad[:, 0], axes=(2, 3), device=device)
    x_rot = rotate_tensor(x_rot, angles_rad[:, 1], axes=(2, 4), device=device)
    x_rot = rotate_tensor(x_rot, angles_rad[:, 2], axes=(3, 4), device=device)

    shear_factors = torch.FloatTensor(B, 3).uniform_(-shearing_range, shearing_range).to(device)
    x_sheared = shear_tensor(x_rot, shear_factors, device=device)

    max_trans_d = D * translation_range
    max_trans_h = H * translation_range
    max_trans_w = W * translation_range
    trans_d = torch.FloatTensor(B).uniform_(-max_trans_d, max_trans_d).to(device)
    trans_h = torch.FloatTensor(B).uniform_(-max_trans_h, max_trans_h).to(device)
    trans_w = torch.FloatTensor(B).uniform_(-max_trans_w, max_trans_w).to(device)
    translations = torch.stack((trans_d, trans_h, trans_w), dim=1)
    x_translated = shift_tensor(x_sheared, translations, device=device)

    x_noisy = x_translated

    transformed_input = x_noisy
    targets = x

    if args.transform_type == 'euler':
        params = torch.cat([angles_rad, translations], dim=1)

    elif args.transform_type == 'r9':
        r_mats = []
        for i in range(B):
            ax = angles_rad[i, 0].item()
            ay = angles_rad[i, 1].item()
            az = angles_rad[i, 2].item()
            Rb = compose_rotation(ax, ay, az, device)
            r_mats.append(Rb.reshape(1, 9))
        r9 = torch.cat(r_mats, dim=0)
        params = torch.cat([r9, translations], dim=1)

    elif args.transform_type == 'r6':
        r_mats = []
        for i in range(B):
            ax = angles_rad[i, 0].item()
            ay = angles_rad[i, 1].item()
            az = angles_rad[i, 2].item()
            Rb = compose_rotation(ax, ay, az, device)

            r6_once = matrix_to_r6(Rb.unsqueeze(0))            
            R_again = r6_to_matrix(r6_once)                    
            r6_final = matrix_to_r6(R_again)                   
        
            r_mats.append(r6_final)  

        r6 = torch.cat(r_mats, dim=0)
        params = torch.cat([r6, translations], dim=1)

    else:
        raise ValueError(f"Unsupported transform_type: {args.transform_type}")

    return transformed_input, targets, params

def test_original_transformation_consistency():
    device = torch.device("cpu")
    class Args: pass
    args = Args()
    args.transform_type = 'euler'
    x = torch.zeros((1,1,8,8,8), device=device)
    x[:,:,2:6,2:6,2:6] = 1.0
    out, target, params = apply_pretrain_transformations(x, device, 30, 0.2, 0.1, args)
    inv_angles = -params[:, :3]
    inv_trans = -params[:, 3:]
    xs = shear_tensor(out, torch.zeros(1,3), device)
    xt = shift_tensor(xs, inv_trans, device)
    xr = rotate_tensor(xt, inv_angles[:, 0], axes=(2,3), device=device)
    xr = rotate_tensor(xr, inv_angles[:, 1], axes=(2,4), device=device)
    xr = rotate_tensor(xr, inv_angles[:, 2], axes=(3,4), device=device)
    mse = ((xr - x)**2).mean().item()
    print("Original Reconstruction MSE:", mse)
    print("Params:", params)

test_original_transformation_consistency()

Original Reconstruction MSE: 0.024155516177415848
Params: tensor([[ 0.3228,  0.0587, -0.2962,  0.1256,  0.1109,  0.5333]])


### New utils

In [3]:
def rotate_tensor_around_center(tensor, angle_rad, axes, device):
    B, C, D, H, W = tensor.shape
    dtype = tensor.dtype
    cz = ((D/2 + 0.5)/D)*2 - 1
    cy = ((H/2 + 0.5)/H)*2 - 1
    cx = ((W/2 + 0.5)/W)*2 - 1
    offset = torch.tensor([cx, cy, cz], device=device, dtype=dtype)
    R = torch.eye(3, device=device, dtype=dtype).unsqueeze(0).repeat(B,1,1)
    cos_a = torch.cos(angle_rad).type(dtype)
    sin_a = torch.sin(angle_rad).type(dtype)
    if axes == (2, 3):
        R[:,1,1] = cos_a
        R[:,1,2] = -sin_a
        R[:,2,1] = sin_a
        R[:,2,2] = cos_a
    elif axes == (2, 4):
        R[:,0,0] = cos_a
        R[:,0,2] = sin_a
        R[:,2,0] = -sin_a
        R[:,2,2] = cos_a
    elif axes == (3, 4):
        R[:,0,0] = cos_a
        R[:,0,1] = -sin_a
        R[:,1,0] = sin_a
        R[:,1,1] = cos_a
    A = torch.eye(4, device=device, dtype=dtype).unsqueeze(0).repeat(B,1,1)
    A[:, :3, :3] = R
    A[:, 0, 3] = -offset[0]
    A[:, 1, 3] = -offset[1]
    A[:, 2, 3] = -offset[2]
    center_back = torch.eye(4, device=device, dtype=dtype).unsqueeze(0).repeat(B,1,1)
    center_back[:, 0, 3] = offset[0]
    center_back[:, 1, 3] = offset[1]
    center_back[:, 2, 3] = offset[2]
    A = torch.bmm(center_back, A)
    M = A[:, :3, :]
    g = F.affine_grid(M, tensor.size(), align_corners=False)
    return F.grid_sample(tensor, g, mode='bilinear', padding_mode='border', align_corners=False)

def euler_rotate_around_center(x, angles_rad, device):
    B, C, D, H, W = x.shape
    out = x.clone()
    out = rotate_tensor_around_center(out, angles_rad[:, 0], axes=(2,3), device=device)
    out = rotate_tensor_around_center(out, angles_rad[:, 1], axes=(2,4), device=device)
    out = rotate_tensor_around_center(out, angles_rad[:, 2], axes=(3,4), device=device)
    return out

def axis_angle_to_matrix(ax, ay, az, device):
    sx = torch.sin(torch.tensor(ax, device=device))
    cx = torch.cos(torch.tensor(ax, device=device))
    sy = torch.sin(torch.tensor(ay, device=device))
    cy = torch.cos(torch.tensor(ay, device=device))
    sz = torch.sin(torch.tensor(az, device=device))
    cz = torch.cos(torch.tensor(az, device=device))
    Rx = torch.tensor([[1,0,0],[0,cx,-sx],[0,sx,cx]], device=device, dtype=torch.float32)
    Ry = torch.tensor([[cy,0,sy],[0,1,0],[-sy,0,cy]], device=device, dtype=torch.float32)
    Rz = torch.tensor([[cz,-sz,0],[sz,cz,0],[0,0,1]], device=device, dtype=torch.float32)
    return torch.matmul(Rz, torch.matmul(Ry, Rx))

def compose_rotation(ax, ay, az, device):
    Rx = axis_angle_to_matrix(ax, 0.0, 0.0, device)
    Ry = axis_angle_to_matrix(0.0, ay, 0.0, device)
    Rz = axis_angle_to_matrix(0.0, 0.0, az, device)
    R_approx = Rz @ (Ry @ Rx)
    U, _, V = torch.linalg.svd(R_approx)
    det_val = torch.linalg.det(torch.matmul(U, V.transpose(-2,-1)))
    if det_val < 0:
        V[..., -1] *= -1.0
    return torch.matmul(U, V.transpose(-2,-1))

def shear_tensor(x, shear_factors, device):
    B, C, D, H, W = x.shape
    outs = []
    for i in range(B):
        sx, sy, sz = shear_factors[i]
        A = torch.tensor([[1,sx,sy,0],[0,1,sz,0],[0,0,1,0]], device=device, dtype=torch.float32)
        g = F.affine_grid(A.unsqueeze(0), x[i:i+1].size(), align_corners=True)
        s = F.grid_sample(x[i:i+1], g, mode='bilinear', padding_mode='zeros', align_corners=True)
        outs.append(s)
    return torch.cat(outs, dim=0)

# def shift_tensor(tensor, shift, device):
#     B, C, D, H, W = tensor.size()
#     s = shift.clone()
#     s[:,0] = shift[:,0]*2.0/(W-1)
#     s[:,1] = shift[:,1]*2.0/(H-1)
#     s[:,2] = shift[:,2]*2.0/(D-1)
#     A = torch.eye(4, device=device).unsqueeze(0).repeat(B,1,1)
#     A[:, :3, 3] = -s
#     M = A[:, :3, :]
#     g = F.affine_grid(M, tensor.size(), align_corners=True)
#     return F.grid_sample(tensor, g, mode='nearest', padding_mode='border', align_corners=True)

def shift_tensor(tensor, shift, device):
    B, C, D, H, W = tensor.size()
    sx = shift[:, 2] * 2.0 / (W - 1)
    sy = shift[:, 1] * 2.0 / (H - 1)
    sz = shift[:, 0] * 2.0 / (D - 1)
    s_new = torch.stack([sx, sy, sz], dim=1)

    A = torch.eye(4, device=device).unsqueeze(0).repeat(B, 1, 1)
    A[:, :3, 3] = -s_new  
    M = A[:, :3, :]
    g = F.affine_grid(M, tensor.size(), align_corners=True)
    return F.grid_sample(tensor, g, mode='nearest', padding_mode='border', align_corners=True)

def matrix_to_r6(R):
    v1 = R[..., 0]
    v2 = R[..., 1]
    return torch.cat([v1, v2], dim=-1)

def r6_to_matrix(r6):
    if r6.size(-1) != 6:
        raise ValueError("r6_to_matrix expects last dimension=6.")
    e = 1e-7
    v1 = r6[..., :3]
    n1 = torch.norm(v1, dim=-1, keepdim=True)
    v1 = v1/(n1+e)
    v2 = r6[..., 3:]
    dot_v1 = torch.sum(v2*v1, dim=-1, keepdim=True)
    v2 = v2 - dot_v1*v1
    n2 = torch.norm(v2, dim=-1, keepdim=True)
    v2 = v2/(n2+e)
    v3 = torch.cross(v1, v2, dim=-1)
    R = torch.stack([v1,v2,v3], dim=-2)
    d = torch.linalg.det(R)
    v3 = torch.where(d.unsqueeze(-1)<0, -v3, v3)
    R = torch.stack([v1,v2,v3], dim=-2)
    return R

def apply_pretrain_transformations(x, device, max_rotation_angle, translation_range, shearing_range, args):
    B, C, D, H, W = x.shape
    angles_deg = torch.FloatTensor(B,3).uniform_(-max_rotation_angle, max_rotation_angle).to(device)
    angles_rad = angles_deg * math.pi/180.0
    shear_factors = torch.FloatTensor(B,3).uniform_(-shearing_range, shearing_range).to(device)
    # maxd, maxh, maxw = D*translation_range, H*translation_range, W*translation_range
    # td = torch.FloatTensor(B).uniform_(-maxw, maxw).to(device)
    # th = torch.FloatTensor(B).uniform_(-maxh, maxh).to(device)
    # tw = torch.FloatTensor(B).uniform_(-maxd, maxd).to(device)
    # translations = torch.stack((td,th,tw), dim=1)
    maxd = D * translation_range
    maxh = H * translation_range
    maxw = W * translation_range

    td = torch.FloatTensor(B).uniform_(-maxd, maxd).to(device)
    th = torch.FloatTensor(B).uniform_(-maxh, maxh).to(device)
    tw = torch.FloatTensor(B).uniform_(-maxw, maxw).to(device)
    
    translations = torch.stack((td, th, tw), dim=1)
    
    xs = shear_tensor(x, shear_factors, device)
    xt = shift_tensor(xs, translations, device)
    xr = euler_rotate_around_center(xt, angles_rad, device)
    if args.transform_type == 'euler':
        params = torch.cat([angles_rad, translations], dim=1)
    elif args.transform_type == 'r9':
        mats = []
        for i in range(B):
            ax, ay, az = angles_rad[i].tolist()
            R = compose_rotation(ax, ay, az, device)
            mats.append(R.view(1,9))
        r9 = torch.cat(mats, dim=0)
        params = torch.cat([r9, translations], dim=1)
    elif args.transform_type == 'r6':
        mats = []
        for i in range(B):
            ax, ay, az = angles_rad[i].tolist()
            Rb = compose_rotation(ax, ay, az, device)
            r6_once = matrix_to_r6(Rb.unsqueeze(0))
            R_again = r6_to_matrix(r6_once)
            r6_final = matrix_to_r6(R_again)
            mats.append(r6_final)
        r6 = torch.cat(mats, dim=0)
        params = torch.cat([r6, translations], dim=1)
    else:
        raise ValueError("Unsupported transform_type")
    return xr, x, params

def apply_finetune_transformations(x, device, max_rotation_angle, translation_range, shearing_range, args):
    B, C, D, H, W = x.shape
    angles_deg = torch.FloatTensor(B,3).uniform_(-max_rotation_angle, max_rotation_angle).to(device)
    angles_rad = angles_deg * math.pi/180.0
    shear_factors = torch.FloatTensor(B,3).uniform_(-shearing_range, shearing_range).to(device)
    # maxd, maxh, maxw = D*translation_range, H*translation_range, W*translation_range
    # td = torch.FloatTensor(B).uniform_(-maxw, maxw).to(device)
    # th = torch.FloatTensor(B).uniform_(-maxh, maxh).to(device)
    # tw = torch.FloatTensor(B).uniform_(-maxd, maxd).to(device)
    # translations = torch.stack((td,th,tw), dim=1)

    maxd = D * translation_range
    maxh = H * translation_range
    maxw = W * translation_range

    td = torch.FloatTensor(B).uniform_(-maxd, maxd).to(device)
    th = torch.FloatTensor(B).uniform_(-maxh, maxh).to(device)
    tw = torch.FloatTensor(B).uniform_(-maxw, maxw).to(device)
    
    translations = torch.stack((td, th, tw), dim=1)
    
    xs = shear_tensor(x, shear_factors, device)
    xt = shift_tensor(xs, translations, device)
    xr = euler_rotate_around_center(xt, angles_rad, device)
    if args.transform_type == 'euler':
        params = torch.cat([angles_rad, translations], dim=1)
    elif args.transform_type == 'r9':
        mats = []
        for i in range(B):
            ax, ay, az = angles_rad[i].tolist()
            R = compose_rotation(ax, ay, az, device)
            mats.append(R.view(1,9))
        r9 = torch.cat(mats, dim=0)
        params = torch.cat([r9, translations], dim=1)
    elif args.transform_type == 'r6':
        mats = []
        for i in range(B):
            ax, ay, az = angles_rad[i].tolist()
            Rb = compose_rotation(ax, ay, az, device)
            r6_once = matrix_to_r6(Rb.unsqueeze(0))
            R_again = r6_to_matrix(r6_once)
            r6_final = matrix_to_r6(R_again)
            mats.append(r6_final)
        r6 = torch.cat(mats, dim=0)
        params = torch.cat([r6, translations], dim=1)
    else:
        raise ValueError("Unsupported transform_type")
    return xr, x, params

def apply_testset_transformations(x, device, max_rotation_angle, translation_range, args):
    B, C, D, H, W = x.shape
    angles_deg = torch.FloatTensor(B,3).uniform_(-max_rotation_angle, max_rotation_angle).to(device)
    angles_rad = angles_deg * math.pi/180.0
    # maxd, maxh, maxw = D*translation_range, H*translation_range, W*translation_range
    # td = torch.FloatTensor(B).uniform_(-maxw, maxw).to(device)
    # th = torch.FloatTensor(B).uniform_(-maxh, maxh).to(device)
    # tw = torch.FloatTensor(B).uniform_(-maxd, maxd).to(device)
    # translations = torch.stack((td,th,tw), dim=1)

    maxd = D * translation_range
    maxh = H * translation_range
    maxw = W * translation_range

    td = torch.FloatTensor(B).uniform_(-maxd, maxd).to(device)
    th = torch.FloatTensor(B).uniform_(-maxh, maxh).to(device)
    tw = torch.FloatTensor(B).uniform_(-maxw, maxw).to(device)
    
    translations = torch.stack((td, th, tw), dim=1)
    
    xr = euler_rotate_around_center(x, angles_rad, device)
    xt = shift_tensor(xr, translations, device)
    if args.transform_type == 'euler':
        params = torch.cat([angles_rad, translations], dim=1)
    elif args.transform_type == 'r9':
        mats = []
        for i in range(B):
            ax, ay, az = angles_rad[i].tolist()
            R = compose_rotation(ax, ay, az, device)
            mats.append(R.view(1,9))
        r9 = torch.cat(mats, dim=0)
        params = torch.cat([r9, translations], dim=1)
    elif args.transform_type == 'r6':
        mats = []
        for i in range(B):
            ax, ay, az = angles_rad[i].tolist()
            Rb = compose_rotation(ax, ay, az, device)
            r6_once = matrix_to_r6(Rb.unsqueeze(0))
            R_again = r6_to_matrix(r6_once)
            r6_final = matrix_to_r6(R_again)
            mats.append(r6_final)
        r6 = torch.cat(mats, dim=0)
        params = torch.cat([r6, translations], dim=1)
    else:
        raise ValueError("Unsupported transform_type")
    return xt, x, params

def euler_rotate_around_center(x, angles_rad, device):
    B, C, D, H, W = x.shape
    out = x.clone()
    out = rotate_tensor_around_center(out, angles_rad[:, 0], axes=(2,3), device=device)
    out = rotate_tensor_around_center(out, angles_rad[:, 1], axes=(2,4), device=device)
    out = rotate_tensor_around_center(out, angles_rad[:, 2], axes=(3,4), device=device)
    return out

def test_transformation_consistency():
    device = torch.device("cpu")
    class Args: pass
    args = Args()
    args.transform_type = 'euler'
    x = torch.zeros((1,1,8,8,8), device=device)
    x[:,:,2:6,2:6,2:6] = 1.0
    out, target, params = apply_pretrain_transformations(x, device, 30, 0.2, 0.1, args)
    inv_angles = -params[:, :3]
    inv_trans = -params[:, 3:]
    xs = shear_tensor(out, torch.zeros(1,3), device)
    xt = shift_tensor(xs, inv_trans, device)
    rec = euler_rotate_around_center(xt, inv_angles, device)
    mse = ((rec - x)**2).mean().item()
    print("Reconstruction MSE:", mse)
    print("Params:", params)

test_transformation_consistency()

Reconstruction MSE: 0.018647510558366776
Params: tensor([[-0.1707, -0.0163, -0.3261, -1.3201,  0.3330, -0.2923]])
