In [3]:
import torch
import torch.nn.functional as F

# Compute device for this notebook. It is pinned to the CPU; no accelerator
# detection happens here.
# NOTE(review): if GPU use was intended, select "cuda" when
# torch.cuda.is_available() — confirm before changing.
device = torch.device("cpu")
print(f"🚀 Using device: {device}")

# Gaussian Kernel
def gaussian_kernel(size, sigma, channels):
    """Build a normalised 2-D Gaussian filter, stacked for depthwise convolution.

    Args:
        size: Side length of the square kernel (taps per axis).
        sigma: Standard deviation of the Gaussian, in taps.
        channels: Number of identical copies stacked along the first axis.

    Returns:
        A contiguous float32 tensor of shape ``(channels, 1, size, size)``.
        Each ``size x size`` slice sums to 1 (up to rounding). This is the
        weight layout used with ``F.conv2d(..., groups=channels)``.
    """
    # Tap positions centred on zero: for size=11 these are -5, -4, ..., 5;
    # for an even size they are half-integers (e.g. -1.5 .. 1.5 for size=4).
    offsets = torch.arange(size).float() - (size - 1) / 2
    profile = torch.exp(-offsets.pow(2) / (2 * sigma ** 2))
    profile = profile / profile.sum()
    # The 2-D Gaussian is separable: it is the outer product of the 1-D profile.
    plane = torch.outer(profile, profile)
    # repeat() copies the plane, so the result is already contiguous.
    return plane.reshape(1, 1, size, size).repeat(channels, 1, 1, 1)

# Default window for ssim_torch: 11x11 taps, sigma 1.5, 3 channel copies.
SSIM_KERNEL = gaussian_kernel(size=11, sigma=1.5, channels=3)

# SSIM Function
def ssim_torch(img1, img2, kernel=None, data_range=1.0):
    """Return the mean structural similarity (SSIM) between two images.

    Local means, variances and the covariance are computed with a depthwise
    convolution by ``kernel``. The convolution is zero-padded so the SSIM
    map keeps the input's spatial size. The map is then averaged into a
    single number.

    Args:
        img1, img2: Floating-point image tensors laid out as ``F.conv2d``
            expects (e.g. ``(N, C, H, W)``). Their channel count must equal
            ``kernel.shape[0]``. They are combined elementwise, so they are
            presumably the same shape.
        kernel: Depthwise filter of shape ``(C, 1, kH, kW)``. ``None`` (the
            default) uses the module-level ``SSIM_KERNEL``.
        data_range: Dynamic range of the pixel values, e.g. ``1.0`` for
            values in [0, 1] or ``255.0`` for values in [0, 255]. It scales
            the stabilising constants C1 and C2.

    Returns:
        The SSIM averaged over every element of the map, as a Python float.
        Because ``.item()`` is used, the result is detached from autograd.
    """
    if kernel is None:
        # Resolved at call time instead of binding a (mutable) tensor as the
        # default argument when the function is defined.
        kernel = SSIM_KERNEL
    # Match the inputs' device AND dtype: conv2d rejects a float32 kernel
    # applied to float64/float16 images. Non-float inputs keep the kernel's
    # dtype so they are not silently truncated to an all-zero integer kernel.
    target_dtype = img1.dtype if img1.is_floating_point() else kernel.dtype
    kernel = kernel.to(device=img1.device, dtype=target_dtype)
    # One group per kernel copy (depthwise), plus "same" padding derived from
    # the kernel size instead of the hard-coded 3 groups / padding 5 (11x11).
    groups = kernel.shape[0]
    padding = (kernel.shape[-2] // 2, kernel.shape[-1] // 2)
    # Stabilising constants (K1=0.01, K2=0.03) scaled by the data range.
    C1 = (0.01 * data_range) ** 2
    C2 = (0.03 * data_range) ** 2

    def conv(img):
        return F.conv2d(img, kernel, groups=groups, padding=padding)

    mu1, mu2 = conv(img1), conv(img2)
    mu1_sq, mu2_sq, mu1_mu2 = mu1 ** 2, mu2 ** 2, mu1 * mu2
    # Var(X) = E[X^2] - E[X]^2 and Cov(X, Y) = E[XY] - E[X]E[Y], where the
    # local expectations are weighted by the kernel window.
    sigma1_sq = conv(img1 * img1) - mu1_sq
    sigma2_sq = conv(img2 * img2) - mu2_sq
    sigma12 = conv(img1 * img2) - mu1_mu2
    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    return (numerator / denominator).mean().item()


🚀 Using device: cpu
