In [29]:
import torch
import torch.nn.functional as F

In [30]:
torch.manual_seed(42)

<torch._C.Generator at 0x2807f7d6470>

# Input

In [31]:
# Content image (ma trận màu xanh dương):
# [ [3, 2, 4],
#   [1, 9, 5],
#   [8, 6, 7] ]
content_img = torch.tensor(
    [[[[3.0, 2.0, 4.0],
       [1.0, 9.0, 5.0],
       [8.0, 6.0, 7.0]]]],
    requires_grad=False
)

# Style image (ma trận màu xanh lá):
# [ [1, 2, 4],
#   [1, 3, 4],
#   [4, 2, 1] ]
style_img = torch.tensor(
    [[[[1.0, 2.0, 4.0],
       [1.0, 3.0, 4.0],
       [4.0, 2.0, 1.0]]]],
    requires_grad=False
)

# Output image khởi tạo toàn 0 (ma trận màu xám) với requires_grad=True để có thể tối ưu
output_img = torch.zeros_like(content_img, requires_grad=True)

In [32]:
print("=== INPUT IMAGES ===")
print("Content Image:\n", content_img[0,0])
print("Style Image:\n", style_img[0,0])
print("Output Image (init):\n", output_img.detach()[0,0])

=== INPUT IMAGES ===
Content Image:
 tensor([[3., 2., 4.],
        [1., 9., 5.],
        [8., 6., 7.]])
Style Image:
 tensor([[1., 2., 4.],
        [1., 3., 4.],
        [4., 2., 1.]])
Output Image (init):
 tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])


# Model

In [33]:
# 1 input channel, 1 output channel, kernel size = 2x2
conv_weight = torch.tensor([[[[1.0, -1.0],
                              [1.0, 0.0]]]], requires_grad=False)  # shape [1, 1, 2, 2]

In [34]:
def simple_cnn(x, weight):
    return F.conv2d(x, weight, stride=1, padding=0)

In [35]:
# Extract feature maps
F_content = simple_cnn(content_img, conv_weight)
F_style = simple_cnn(style_img, conv_weight)
F_output = simple_cnn(output_img, conv_weight)

In [36]:
print("\n=== FEATURE MAPS ===")
print("F_content:\n", F_content[0,0])
print("F_style:\n", F_style[0,0])
print("F_output:\n", F_output[0,0])


=== FEATURE MAPS ===
F_content:
 tensor([[ 2.,  7.],
        [ 0., 10.]])
F_style:
 tensor([[0., 1.],
        [2., 1.]])
F_output:
 tensor([[0., 0.],
        [0., 0.]], grad_fn=<SelectBackward0>)


# Loss and Gram matrix

In [37]:
def gram_matrix(x):
    b, c, h, w = x.shape
    features = x.view(c, h * w)
    return torch.mm(features, features.t())  # shape: [c, c]

In [38]:
def content_loss(F_target, F_content):
    return F.mse_loss(F_target, F_content)

def style_loss(F_target, F_style):
    G_target = gram_matrix(F_target)
    G_style = gram_matrix(F_style)
    return F.mse_loss(G_target, G_style)

In [39]:
# Tính loss
c_loss = content_loss(F_output, F_content)
s_loss = style_loss(F_output, F_style)
total_loss = c_loss + s_loss

In [40]:
print("\n=== LOSSES ===")
print(f"Content Loss: {c_loss.item():.4f}")
print(f"Style Loss: {s_loss.item():.4f}")
print(f"Total Loss: {total_loss.item():.4f}")


=== LOSSES ===
Content Loss: 38.2500
Style Loss: 36.0000
Total Loss: 74.2500


In [41]:
# Backward
total_loss.backward()

# Learning rate
lr = 0.01
with torch.no_grad():
    output_img -= lr * output_img.grad
    output_img.grad.zero_()

In [42]:
print("\n=== GRADIENTS ===")
print("Grad of output_img:\n", output_img.grad[0,0])

print("\n=== OUTPUT IMAGE AFTER 1 UPDATE ===")
print(output_img.detach()[0,0])


=== GRADIENTS ===
Grad of output_img:
 tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])

=== OUTPUT IMAGE AFTER 1 UPDATE ===
tensor([[ 0.0100,  0.0250, -0.0350],
        [ 0.0100,  0.0850, -0.0500],
        [ 0.0000,  0.0500,  0.0000]])
