# Color Inversion

# In [1]:
import torch

def solve(image: list[int], width: int, height: int):
    """Invert the RGB channels of an RGBA image in place, leaving alpha untouched.

    Each colour channel ``c`` becomes ``255 - c``; every fourth value (alpha)
    is copied through unchanged.

    Args:
        image: Flat list of ``width * height * 4`` channel values laid out as
            ``R, G, B, A`` per pixel, each in ``0..255``. It is overwritten in
            place with the result.
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        None. The result is written back into ``image``.

    Raises:
        ValueError: If ``len(image)`` does not equal ``width * height * 4``.
    """
    expected = width * height * 4
    if len(image) != expected:
        raise ValueError(
            f"image has {len(image)} values, expected width*height*4 = {expected}"
        )

    # uint8 matches the 0..255 channel range, and 255 - x cannot underflow
    # for any x in that range.
    pixels = torch.tensor(image, dtype=torch.uint8).view(-1, 4)

    # Invert R, G, B (columns 0..2) in place; column 3 (alpha) is not touched,
    # so no split/concatenate round-trip is needed.
    pixels[:, :3] = 255 - pixels[:, :3]

    # Slice-assign so callers holding a reference to `image` see the result.
    image[:] = pixels.flatten().tolist()

# In [None]:
# Worked examples. Each case is (pixels, width, height). solve() rewrites the
# list in place, so printing it afterwards shows the inverted RGBA values.
# Expected output:
#   Example 1 -> [0, 255, 127, 255, 255, 0, 255, 255]
#   Example 2 -> [245, 235, 225, 255, 155, 105, 55, 255]
for img, w, h in (
    ([255, 0, 128, 255, 0, 255, 0, 255], 1, 2),
    ([10, 20, 30, 255, 100, 150, 200, 255], 2, 1),
):
    solve(img, width=w, height=h)
    print(img)