In [4]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple CNN model
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=3)

    def forward(self, x):
        return self.conv(x)

# Initialize model and optimizer
model = SimpleCNN()
optimizer = optim.Adam(model.parameters(), lr=0.1)

# Dummy input and target
input_tensor = torch.randn(1, 3, 32, 32)
target = torch.randn(1, 10, 30, 30)

# Forward pass
output = model(input_tensor)
loss = nn.MSELoss()(output, target)
loss.backward()
optimizer.step()

for param in model.parameters():
    print("++ para 1", param.grad.shape)
    print("++ para 1", param.grad.flatten()[0])
    
for param in model.parameters():
    print("++ para 2", param.grad.shape)
    print("++ para 2", param.grad.flatten()[0])

print("\n", model.conv.weight[0])

# Replace a filter in the convolution layer
with torch.no_grad():
    new_filter = torch.ones_like(model.conv.weight[0])  # Generate a new filter with the same shape
    model.conv.weight = torch.nn.Parameter(torch.ones_like(model.conv.weight)) # new_filter)  # In-place replacement

print("\n", model.conv.weight[0])

# optimizer = optim.Adam(model.parameters(), lr=0.1) # we need to reset the optimiser!!

optimizer = optim.Adam(model.parameters(), lr=0.1)

# Ensure the optimizer still tracks the parameters
output = model(input_tensor)
loss = nn.MSELoss()(output, target)
loss.backward()
optimizer.step()

for param in model.parameters():
    print("++ para 3", param.grad.shape)
    print("++ para 3", param.grad.flatten()[0])

print("\n", model.conv.weight[0])

input_tensor = torch.randn(1, 3, 32, 32)
target = torch.randn(1, 10, 30, 30)

output = model(input_tensor)
loss = nn.MSELoss()(output, target)
loss.backward()
optimizer.step()

print("\n", model.conv.weight[0])

input_tensor = torch.randn(1, 3, 32, 32)
target = torch.randn(1, 10, 30, 30)

output = model(input_tensor)
loss = nn.MSELoss()(output, target)
loss.backward()
optimizer.step()

print("\n", model.conv.weight[0])


print("Filter replaced and updated successfully!")

++ para 1 torch.Size([10, 3, 3, 3])
++ para 1 tensor(0.0160)
++ para 1 torch.Size([10])
++ para 1 tensor(0.0115)
++ para 2 torch.Size([10, 3, 3, 3])
++ para 2 tensor(0.0160)
++ para 2 torch.Size([10])
++ para 2 tensor(0.0115)

 tensor([[[ 0.0247, -0.0255,  0.0839],
         [ 0.0488,  0.0410,  0.0630],
         [ 0.0259, -0.0305,  0.0154]],

        [[ 0.0643, -0.0218,  0.0294],
         [ 0.0943, -0.0487,  0.0184],
         [ 0.1097,  0.1230, -0.0317]],

        [[-0.0063,  0.0250, -0.0871],
         [ 0.0921,  0.0216, -0.0016],
         [-0.0241,  0.0095, -0.0833]]], grad_fn=<SelectBackward0>)

 tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]], grad_fn=<SelectBackward0>)
++ para 3 torch.Size([10, 3, 3, 3])
++ para 3 tensor(0.2253)
++ para 3 torch.Size([10])
++ para 3 tensor(0.2025)

 tensor([[[0.9000, 0.9000, 0.9000],
     

In [2]:
model

SimpleCNN(
  (conv): Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1))
)