# In [None]:
import torch 
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import matplotlib.pylab as plt
import numpy as np
from PIL import Image

# Dataset location. DATA_DIR keeps its trailing "/" because image paths are
# built by plain string concatenation with the CSV's `filepaths` column.
DATA_DIR = './dataset/Cards-Image/'
# CSV index of the dataset; its `labels` and `filepaths` columns are used below.
CSV_FILE = "./dataset/Cards-Image/cards.csv"

import pandas as pd
# Read the dataset index and pick the first row for each of the two demo cards.
# `.iloc[0]` raises IndexError if a label has no matching row.
df = pd.read_csv(CSV_FILE)
labels = df["labels"]
queen_of_hearts = df.loc[labels == "queen of hearts"].iloc[0]
ace_of_spades = df.loc[labels == "ace of spades"].iloc[0]
print(f'{queen_of_hearts=} \n  {ace_of_spades=}')

# Target side length for an optional resize step. Resizing is currently
# disabled; to enable it, put transforms.Resize((IMG_SIZE, IMG_SIZE)) first.
IMG_SIZE = 128
transform_image = transforms.Compose(
    [
        # PIL image -> float tensor in (C, H, W) layout (values in [0, 1] for 8-bit images).
        transforms.ToTensor(),
        # mean=0 and std=1 make this an identity; it is a placeholder for real
        # per-channel statistics. Note that it hard-codes three channels.
        transforms.Normalize(mean=(0., 0., 0.), std=(1., 1., 1.)),
    ]
)

def _load_card_tensor(row):
    """Load the image referenced by a CSV row and return it as a transformed tensor.

    The file is opened in a ``with`` block so its handle is closed once the
    pixels are read. The image is converted to RGB first because
    ``transform_image`` normalizes exactly three channels and ``conv2d``
    expects ``in_channels=3``. For an already-RGB image the conversion leaves
    the pixel data unchanged. Grayscale or RGBA/palette images would otherwise
    fail later with a shape error.

    Args:
        row: a row of ``df`` whose ``filepaths`` entry is a path relative to DATA_DIR.

    Returns:
        The output of ``transform_image`` for that image.
    """
    path = DATA_DIR + row.filepaths
    with Image.open(path) as img:
        return transform_image(img.convert("RGB"))


imgQH = _load_card_tensor(queen_of_hearts)
imgAS = _load_card_tensor(ace_of_spades)

# Show the two transformed cards side by side. permute turns each (C, H, W)
# tensor into the (H, W, C) layout that imshow expects.
h, axes = plt.subplots(1, 2, figsize=(10, 4))
for ax, card in zip(axes, (imgQH, imgAS)):
    ax.imshow(card.permute(1, 2, 0))
    ax.axis("off")
plt.show()

# Batch the two (C, H, W) images along a new leading axis, giving (2, C, H, W)
# for conv2d. torch.stack requires both images to have identical shapes.
imgTensor = torch.stack([imgQH, imgAS])

# 3 input channels -> 4 response maps. A 3x3 kernel with stride 1 and padding 1
# preserves H and W. Replicate padding repeats the border pixels outward,
# which avoids the artificial edges zero padding would create along the image
# border. No bias, so each output is a pure filter response.
conv2d = nn.Conv2d(
    in_channels=3,
    out_channels=4,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=False,
    padding_mode="replicate",
)
state_dict = conv2d.state_dict()
print(f'{state_dict=}')
# Weight layout is (out_channels, in_channels, kH, kW).
print(conv2d.weight.shape)

# Sobel kernels for horizontal (Gx) and vertical (Gy) gradients; Gy is exactly
# the transpose of Gx. Dividing by 4 makes the positive taps sum to 1 and the
# negative taps sum to -1, so responses to inputs in [0, 1] stay within [-1, 1].
sobelGx = torch.tensor(
    [[1., 0., -1.],
     [2., 0., -2.],
     [1., 0., -1.]],
    dtype=torch.float32,
) / 4.0
sobelGy = sobelGx.T.contiguous()

# One (in_channels=3, 3, 3) filter per conv2d output channel:
#   filterQH_*: Sobel applied to input channel 0 only (red, assuming RGB input);
#   filterAS_*: Sobel applied to the mean of the three input channels
#               (by linearity, three copies of kernel/3 summed over channels).
blank = torch.zeros_like(sobelGx)
filterQH_Gx = torch.stack([sobelGx, blank, blank])
filterQH_Gy = torch.stack([sobelGy, blank, blank])
filterAS_Gx = (sobelGx / 3).repeat(3, 1, 1)
filterAS_Gy = (sobelGy / 3).repeat(3, 1, 1)

print(f'{filterQH_Gy=}')
# Stack into nn.Conv2d's weight layout: (out_channels, in_channels, kH, kW).
weightTensor = torch.stack([filterQH_Gx, filterQH_Gy, filterAS_Gx, filterAS_Gy])
print(f'{weightTensor=} {weightTensor.shape=}')

# Install the hand-built Sobel filters as conv2d's weights. load_state_dict
# (strict by default) rejects a weightTensor whose shape differs from the
# weight parameter's.
state_dict = conv2d.state_dict()
state_dict["weight"] = weightTensor
conv2d.load_state_dict(state_dict)

# Inference only: the conv weight is a trainable parameter, so without no_grad
# the forward pass would record an autograd graph that is never used.
with torch.no_grad():
    result = conv2d(imgTensor)
print(result.shape)

# Plot every response map in a grid: one row per input image, one column per
# conv2d output channel. vmin/vmax = -1/1 presumably match the [-1, 1] range of
# the normalized Sobel filters for inputs in [0, 1].
# NOTE: the per-cell min/max are stored in fmin/fmax rather than `min`/`max`
# so the module-level builtins are not shadowed.
n_images, n_filters = result.shape[0], result.shape[1]
response = result.detach()  # drop any autograd history once, before converting to numpy
h = plt.figure(figsize=(14,10))
for i in range(n_images):
    for j in range(n_filters):
        fmap = response[i, j]
        fmin, fmax = torch.min(fmap), torch.max(fmap)
        print(f'min={fmin!r} max={fmax!r}')
        # Subplot indices are 1-based and run row-major across the grid.
        h.add_subplot(n_images, n_filters, i * n_filters + j + 1)
        print(fmap.shape)
        plt.imshow(fmap.numpy(), vmin=-1, vmax=1, cmap='gray')
        plt.axis("off")
plt.show()