This is an experiment in which I implement my own 2D partial convolution block.

## Exercise 1: Implement `__init__`

```python
class PartialConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
        super().__init__()
        
        # TODO: Create these yourself
        # 1. self.conv = ???
        # 2. self.bias = ???
        # 3. self.register_buffer('ones_kernel', ???)
        # 4. self.window_size = ???
```

Write this out by hand. Don't copy-paste from your old code.

---

## Exercise 2: Implement `forward` (One Step at a Time)

Start with just Step 1:

```python
def forward(self, x, mask):
    # Step 1: Mask the input
    x_masked = x * mask
    
    print(f"Input shape: {x.shape}")
    print(f"Mask shape: {mask.shape}")
    print(f"Masked input shape: {x_masked.shape}")
    
    # For now, just return dummy values
    return x_masked, mask
```

**Test it**:
```python
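import torch

pconv = PartialConv2d(3, 64, kernel_size=3, padding=1)
x = torch.randn(1, 3, 32, 32)
mask = torch.ones(1, 3, 32, 32)
mask[:, :, 10:20, 10:20] = 0  # Create a hole

out, new_mask = pconv(x, mask)
# With only Step 1 in place, forward returns x_masked: [1, 3, 32, 32].
# Once the convolution is added, this should be [1, 64, 32, 32].
print(out.shape)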
```

Once Step 1 works, add Step 2, test again, etc.
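
For orientation, here is a minimal sketch of where the remaining steps could lead, written as a standalone function rather than a module so it does not replace your own class. The name `partial_conv2d`, the step numbering after Step 1, and the choice to repeat one shared mask across the output channels are assumptions made for illustration, not the only way to do it.

```python
import torch
import torch.nn.functional as F


def partial_conv2d(x, mask, weight, bias, padding=0):
    """Hypothetical functional sketch of one partial-convolution step."""
    out_channels, in_channels, kh, kw = weight.shape
    ones_kernel = torch.ones(1, in_channels, kh, kw, dtype=x.dtype, device=x.device)
    window_size = in_channels * kh * kw

    # Step 1: mask the input so hole pixels contribute nothing
    x_masked = x * mask
    # Step 2 (assumed): ordinary convolution, bias held back until after rescaling
    raw = F.conv2d(x_masked, weight, bias=None, padding=padding)

    with torch.no_grad():
        # Step 3 (assumed): count the valid input elements under each window
        valid_count = F.conv2d(mask, ones_kernel, padding=padding)
        # Step 4 (assumed): a window with at least one valid element gives a valid output
        new_mask = (valid_count > 0).to(x.dtype)
        # Step 5 (assumed): scale by window_size / valid_count, 0 where nothing is valid
        scale = window_size / valid_count.clamp(min=1.0) * new_mask

    # Step 6 (assumed): rescale, add the bias, zero out positions that are still invalid
    out = (raw * scale + bias.view(1, -1, 1, 1)) * new_mask
    # Repeat the mask across output channels so it can feed a following layer
    return out, new_mask.expand(-1, out_channels, -1, -1)


x = torch.randn(1, 3, 32, 32)
mask = torch.ones(1, 3, 32, 32)
mask[:, :, 10:20, 10:20] = 0  # same hole as in the test above
weight = torch.randn(64, 3, 3, 3)
bias = torch.zeros(64)

out, new_mask = partial_conv2d(x, mask, weight, bias, padding=1)
print(out.shape, new_mask.shape)
```

Printing the intermediate tensors (`valid_count`, `scale`, `new_mask`) after each step is a quick way to compare against your own `forward` as you build it up.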

---

## Exercise 3: Test Mask Propagation

Create a visual test:

```python
# Create input with big hole in center
x = torch.randn(1, 1, 32, 32)
mask = torch.ones(1, 1, 32, 32)
mask[:, :, 12:20, 12:20] = 0  # 8×8 hole

# Apply partial conv once
pconv = PartialConv2d(1, 1, kernel_size=3, padding=1)
out1, mask1 = pconv(x, mask)

# Visualize
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 3)
axes[0].imshow(mask[0,0].numpy(), cmap='gray')
axes[0].set_title('Original Mask')
axes[1].imshow(mask1[0,0].numpy(), cmap='gray')
axes[1].set_title('After 1 Layer')

# Apply again
out2, mask2 = pconv(out1, mask1)
axes[2].imshow(mask2[0,0].numpy(), cmap='gray')
axes[2].set_title('After 2 Layers')
plt.show()
```

**Expected result**: The hole should **shrink** by ~1 pixel on each edge per layer (for kernel_size=3).
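
The expected shrinkage can also be checked numerically on the mask alone, without the convolution weights. The sketch below assumes the mask update rule "a window is valid if it contains at least one valid pixel"; `update_mask` and `hole_height` are hypothetical helper names.

```python
import torch
import torch.nn.functional as F


def update_mask(mask, kernel_size=3, padding=1):
    """Mark an output position valid if its window holds any valid pixel."""
    ones_kernel = torch.ones(1, mask.shape[1], kernel_size, kernel_size)
    valid_count = F.conv2d(mask, ones_kernel, padding=padding)
    return (valid_count > 0).float()


def hole_height(mask):
    """Number of rows that still contain at least one invalid pixel."""
    return int((mask[0, 0] == 0).any(dim=1).sum())


mask = torch.ones(1, 1, 32, 32)
mask[:, :, 12:20, 12:20] = 0  # 8x8 hole

heights = [hole_height(mask)]
for _ in range(4):
    mask = update_mask(mask)
    heights.append(hole_height(mask))

print(heights)  # [8, 6, 4, 2, 0]
```

If your own layer produces a different sequence for the same hole, the mask update is the first place to look.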

---

# Critical Debugging Questions

When you implement this, ask yourself:

1. **Does the mask shrink correctly?**
   - After 1 layer with a 3×3 kernel and `padding=1`, an 8×8 hole should become 6×6

2. **Do the output values make sense?**
   - Where every pixel in the window is valid, the output should match a regular conv with the same weights (the scaling factor is 1 there)
   - In newly filled regions, the output should be a rescaled blend of the nearby valid pixels

3. **What happens at boundaries?**
   - Edge pixels have fewer valid neighbors (zero padding counts as invalid) - the scaling factor compensates for this

4. **Does it work with stride > 1?**
   - The mask should be downsampled with the same stride

5. **Are gradients flowing?**
   - Run a backward pass and check that `conv.weight.grad` is not `None`

6. **Memory issues?**
   - ones_kernel might be huge - use slicing if needed
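
Questions 3 to 5 can be probed with a few lines before the full layer exists. This is a sketch under assumptions: it uses a fully valid 8×8 mask so that only the borders matter, computes the scaling factor as `window_size / valid_count`, and stands in a plain `nn.Conv2d` for the layer's convolution.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

kernel_size, padding = 3, 1
mask = torch.ones(1, 1, 8, 8)  # fully valid, so only the borders matter
ones_kernel = torch.ones(1, 1, kernel_size, kernel_size)
window_size = kernel_size * kernel_size

# Question 3: zero padding lowers the valid count at the borders
valid_count = F.conv2d(mask, ones_kernel, padding=padding)
scale = window_size / valid_count
corner, edge, interior = scale[0, 0, 0, 0], scale[0, 0, 0, 3], scale[0, 0, 3, 3]
print(corner.item(), edge.item(), interior.item())  # 2.25 1.5 1.0

# Question 4: the mask statistics shrink with the same stride as the features
strided_count = F.conv2d(mask, ones_kernel, stride=2, padding=padding)
print(tuple(strided_count.shape))  # (1, 1, 4, 4)

# Question 5: gradients reach the conv weights through the rescaled output
conv = nn.Conv2d(1, 1, kernel_size, padding=padding, bias=False)
x = torch.randn(1, 1, 8, 8)
out = conv(x * mask) * scale
out.sum().backward()
print(conv.weight.grad is not None)  # True
```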

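```python
# start here
import torch
import torch.nn as nn
import torch.nn.functional as F


class PartialConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
        super().__init__()
        # Assumes kernel_size is a single int (square kernel).

        # 1. Convolution layer (note the capital C in nn.Conv2d).
        #    Its own bias is disabled so the bias can be added after rescaling.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              padding=padding, bias=False)
        # 2. One learnable bias value per output channel
        self.bias = nn.Parameter(torch.zeros(out_channels))
        # 3. Fixed all-ones kernel used to count valid pixels in each window.
        #    register_buffer is a method call, not an attribute assignment.
        self.register_buffer(
            'ones_kernel', torch.ones(1, in_channels, kernel_size, kernel_size))
        # 4. Number of elements in one sliding window
        self.window_size = in_channels * kernel_size * kernel_size
```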
