In [11]:
import os

import torch
import torch.nn.functional as F


# 3x3 Sobel stencil for the horizontal derivative (differences across columns).
sobel_dx = torch.tensor(
    [
        [-1.0, 0.0, 1.0],
        [-2.0, 0.0, 2.0],
        [-1.0, 0.0, 1.0],
    ],
    dtype=torch.float32,
)

# 3x3 Sobel stencil for the vertical derivative (differences across rows).
sobel_dy = torch.tensor(
    [
        [-1.0, -2.0, -1.0],
        [0.0, 0.0, 0.0],
        [1.0, 2.0, 1.0],
    ],
    dtype=torch.float32,
)

# Stack the two stencils along a new leading axis ([2,3,3]), insert an
# input-channel axis and copy each stencil across 3 input channels, giving a
# conv2d weight of shape [2,3,3,3] (out_channels, in_channels, kH, kW).
kernel = torch.stack((sobel_dx, sobel_dy), dim=0)[:, None].repeat(1, 3, 1, 1)

def sobel_filter(img: torch.Tensor) -> torch.Tensor:
    """
    Apply the module-level Sobel `kernel` to a batch of 3-channel images.

    img: Nx3xHxW floating-point tensor (e.g. float32 in [0,1] or [0,255]),
         on any device.
    returns: Nx2xHxW  (channel 0 = ∂I/∂x, channel 1 = ∂I/∂y)

    Because the same stencil is repeated over the 3 input channels, each
    output channel is the sum of the per-channel responses. `padding=1`
    keeps the spatial size unchanged for the 3x3 stencil.
    """
    # conv2d requires input and weight to share device and dtype. The
    # module-level kernel is created as a default-device float32 tensor, so
    # follow the input instead of failing on e.g. float16 or accelerator
    # tensors.
    weight = kernel.to(img.device)
    if img.is_floating_point():
        # Only follow floating dtypes: casting the signed stencil to an
        # integer dtype could corrupt it, so integer inputs still raise in
        # conv2d exactly as before.
        weight = weight.to(img.dtype)
    return F.conv2d(img, weight, padding=1)

def sobel_magnitude(img: torch.Tensor) -> torch.Tensor:
    """
    Gradient magnitude of `img` computed from `sobel_filter`.

    The filter output is squared element-wise, summed over dim 1 (kept as a
    size-1 dimension) and square-rooted.
    """
    gradients = sobel_filter(img)
    squared_sum = torch.sum(gradients.pow(2), dim=1, keepdim=True)
    return torch.sqrt(squared_sum)



In [12]:
# Shape check: run one random 1x3x1428x1904 frame through the filter and
# display the shape of the result as the cell output.
img = torch.rand((1, 3, 1428, 1904))
sobel_filter(img).shape

torch.Size([1, 2, 1428, 1904])

In [27]:
class Sobel(torch.nn.Module):
    """
    Fixed (non-trainable) 3x3 Sobel filter for 3-channel inputs.

    The weight has shape [2,3,3,3]: output channel 0 uses the horizontal
    stencil and output channel 1 the vertical stencil, each repeated over the
    3 input channels. It is stored as a buffer rather than a parameter: it is
    a constant, so it should not be returned by `parameters()`, but it still
    follows `.to(...)` / `.half()` and is saved in the `state_dict` under the
    key "sobel_kernel".
    """

    def __init__(self):
        super().__init__()
        sobel_dx = torch.tensor([[-1, 0, 1],
                                 [-2, 0, 2],
                                 [-1, 0, 1]], dtype=torch.float32)

        sobel_dy = torch.tensor([[-1, -2, -1],
                                 [ 0,  0,  0],
                                 [ 1,  2,  1]], dtype=torch.float32)

        sobel_kernel = torch.stack([sobel_dx, sobel_dy])   # [2,3,3]
        sobel_kernel = sobel_kernel.unsqueeze(1).repeat(1, 3, 1, 1)  # [2,3,3,3]

        self.register_buffer("sobel_kernel", sobel_kernel)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: Nx3xHxW tensor with the same dtype and device as the module.
        returns: Nx2xHxW; stride 1 with padding 1 keeps the spatial size.
        """
        return F.conv2d(x, self.sobel_kernel, padding=1, stride=1)

# Script the Sobel module once per precision and save each variant.
# The target directory is created first so a fresh checkout does not fail on save.
os.makedirs("models", exist_ok=True)

# Prefer Apple's MPS backend when present; otherwise fall back to CPU so the
# cell still runs on machines without it.
device = "mps" if torch.backends.mps.is_available() else "cpu"

# float16 is processed last on purpose: after this cell `sobel` and `sm`
# refer to the float16 variant, which later cells rely on.
for dtype, path in (
    (torch.float32, "models/sobel_float32.pt"),
    (torch.float16, "models/sobel_float16.pt"),
):
    sobel = Sobel().to(device=device, dtype=dtype)
    sm = torch.jit.script(sobel)
    sm.save(path)

In [18]:
# Load the float16 TorchScript export written by the save cell. The previous
# path "models/sobel.pt" is not produced anywhere in this notebook, and the
# comparison below feeds float16 input.
m = torch.jit.load("models/sobel_float16.pt")

In [21]:
# Random float16 probe image, placed on the device that holds the in-memory
# module's kernel; a CPU tensor fed to a module moved to another device
# raises a device-mismatch error.
img = torch.rand(1, 3, 1428, 1904).to(device=sobel.sobel_kernel.device, dtype=torch.float16)
existing_model_output = sobel(img)

In [None]:
# Run the reloaded TorchScript module on the same input, then report whether
# it agrees with the in-memory module (atol=1e-5), followed by both shapes.
loaded_model_output = m(img)
print(
    torch.allclose(existing_model_output, loaded_model_output, atol=1e-5),
    existing_model_output.shape,
    loaded_model_output.shape,
    sep="\n",
)

True
torch.Size([1, 2, 1428, 1904])
torch.Size([1, 2, 1428, 1904])


In [None]:
import cv2

# Record the default camera to 'output.mp4' while previewing it in a window.
# Press 'q' in the preview window to stop.

# Open the default camera
cam = cv2.VideoCapture(0)
out = None

try:
    if not cam.isOpened():
        raise RuntimeError("Could not open the default camera (index 0)")

    # Get the default frame width and height
    frame_width = int(cam.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cam.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Define the codec and create VideoWriter object
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter('output.mp4', fourcc, 20.0, (frame_width, frame_height))

    while True:
        ret, frame = cam.read()

        # A failed grab (camera unplugged, stream ended) yields no frame;
        # stop instead of passing an invalid frame to write()/imshow().
        if not ret:
            break

        # Write the frame to the output file
        out.write(frame)

        # Display the captured frame
        cv2.imshow('Camera', frame)

        # Press 'q' to exit the loop. Mask to the low byte because some
        # backends return the key code with extra high bits set.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Always release the capture and writer objects and close the window,
    # even if the loop is interrupted or raises.
    cam.release()
    if out is not None:
        out.release()
    cv2.destroyAllWindows()



: 