In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

def gram_matrix(mat):
    G = mat@mat.T    
    return G

style_mat = torch.tensor([
    [2.0,  3.0],
    [-1.0, 3.5]
])
style_mat.requires_grad_(False)
print("style_mat:\n", style_mat)

G_style = gram_matrix(style_mat)
print("G_style:\n", G_style)

input_mat = nn.Parameter(torch.randn(2, 2))
print("input_mat:\n", input_mat)

style_mat:
 tensor([[ 2.0000,  3.0000],
        [-1.0000,  3.5000]])
G_style:
 tensor([[13.0000,  8.5000],
        [ 8.5000, 13.2500]])
input_mat:
 Parameter containing:
tensor([[-0.4368, -0.0517],
        [ 1.0422,  0.0646]], requires_grad=True)


In [2]:
criterion = nn.MSELoss()
optimizer = optim.Adam([input_mat], lr=0.01)

num_epochs = 2000
for epoch in range(num_epochs):
    optimizer.zero_grad()

    G_input = gram_matrix(input_mat)
    loss = criterion(G_input, G_style)

    loss.backward()
    optimizer.step()

    if (epoch+1) % 100 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss = {loss.item():.6f}")

print("style_mat:\n", style_mat)
print("input_mat:\n", input_mat)

Epoch [100/2000], Loss = 68.361206
Epoch [200/2000], Loss = 1.625280
Epoch [300/2000], Loss = 0.412603
Epoch [400/2000], Loss = 0.063498
Epoch [500/2000], Loss = 0.004668
Epoch [600/2000], Loss = 0.000185
Epoch [700/2000], Loss = 0.000004
Epoch [800/2000], Loss = 0.000000
Epoch [900/2000], Loss = 0.000000
Epoch [1000/2000], Loss = 0.000000
Epoch [1100/2000], Loss = 0.000000
Epoch [1200/2000], Loss = 0.000000
Epoch [1300/2000], Loss = 0.000000
Epoch [1400/2000], Loss = 0.000000
Epoch [1500/2000], Loss = 0.000000
Epoch [1600/2000], Loss = 0.000000
Epoch [1700/2000], Loss = 0.000000
Epoch [1800/2000], Loss = 0.000000
Epoch [1900/2000], Loss = 0.000000
Epoch [2000/2000], Loss = 0.000000
style_mat:
 tensor([[ 2.0000,  3.0000],
        [-1.0000,  3.5000]])
input_mat:
 Parameter containing:
tensor([[ 1.5889, -3.2366],
        [ 3.5286, -0.8940]], requires_grad=True)


In [3]:
gram_matrix(style_mat)

tensor([[13.0000,  8.5000],
        [ 8.5000, 13.2500]])

In [4]:
gram_matrix(input_mat)

tensor([[13.0000,  8.5000],
        [ 8.5000, 13.2500]], grad_fn=<MmBackward0>)