In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

def gram_matrix(mat):
    """Return the (unnormalized) Gram matrix ``mat @ mat^T``.

    Each entry ``G[i, j]`` is the inner product of row ``i`` with row ``j``.

    Args:
        mat: tensor of shape ``(rows, cols)``, or batched ``(..., rows, cols)``.
            A 1-D tensor is treated as a single row and yields the scalar
            ``mat @ mat``.

    Returns:
        Tensor of shape ``(rows, rows)`` (or ``(..., rows, rows)`` for batched
        input). No normalization by the number of elements is applied.
    """
    if mat.dim() < 2:
        # ``.T`` is a no-op on 1-D/0-D tensors, so this matches the original
        # ``mat @ mat.T`` exactly (including the error raised for 0-D input).
        return mat @ mat
    # Swap only the last two axes. ``.T`` would reverse *every* axis for
    # ndim > 2 (deprecated in PyTorch), mixing batch and feature dimensions.
    return mat @ mat.transpose(-2, -1)

# Target ("style") matrix. torch.tensor(...) already has requires_grad=False;
# the explicit call below just documents that the target is never optimized.
style_mat = torch.tensor([
    [2.0,  3.0],
    [-1.0, 4.0]
])
style_mat.requires_grad_(False)
G_style = gram_matrix(style_mat)
print("style_mat:\n", style_mat)
print("G_style:\n", G_style)

# Seed the RNG so the random initial guess (and therefore the fitted result)
# is reproducible under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

# Trainable matrix. Its row count must equal style_mat's so the two Gram
# matrices have the same shape; the column count is a free choice, but it must
# be at least the rank of G_style for an exact match to exist.
N_COLS = 3
input_mat = nn.Parameter(torch.randn(style_mat.shape[0], N_COLS))
print("input_mat:\n", input_mat)

style_mat:
 tensor([[ 2.,  3.],
        [-1.,  4.]])
G_style:
 tensor([[13., 10.],
        [10., 17.]])
input_mat:
 Parameter containing:
tensor([[-1.2328e-01,  1.0980e+00,  1.4641e+00],
        [ 1.1249e-03, -2.9110e-01, -2.4231e+00]], requires_grad=True)


In [2]:
# Optimize `input_mat` so that its Gram matrix reproduces the style Gram matrix.
num_epochs = 2000
log_every = 100

criterion = nn.MSELoss()
optimizer = optim.Adam([input_mat], lr=0.01)

for epoch in range(num_epochs):
    # Drop gradients accumulated by the previous iteration.
    optimizer.zero_grad()

    G_input = gram_matrix(input_mat)
    loss = criterion(G_input, G_style)
    loss.backward()
    optimizer.step()

    # Report every `log_every` epochs; `loss` is the value computed before this step.
    if (epoch+1) % log_every != 0:
        continue
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss = {loss.item():.6f}")

print("style_mat:\n", style_mat)
print("input_mat:\n", input_mat)

Epoch [100/2000], Loss = 40.950680
Epoch [200/2000], Loss = 0.712595
Epoch [300/2000], Loss = 0.071621
Epoch [400/2000], Loss = 0.007151
Epoch [500/2000], Loss = 0.000528
Epoch [600/2000], Loss = 0.000027
Epoch [700/2000], Loss = 0.000001
Epoch [800/2000], Loss = 0.000000
Epoch [900/2000], Loss = 0.000000
Epoch [1000/2000], Loss = 0.000000
Epoch [1100/2000], Loss = 0.000000
Epoch [1200/2000], Loss = 0.000000
Epoch [1300/2000], Loss = 0.000000
Epoch [1400/2000], Loss = 0.000000
Epoch [1500/2000], Loss = 0.000000
Epoch [1600/2000], Loss = 0.000000
Epoch [1700/2000], Loss = 0.000000
Epoch [1800/2000], Loss = 0.000000
Epoch [1900/2000], Loss = 0.000000
Epoch [2000/2000], Loss = 0.000000
style_mat:
 tensor([[ 2.,  3.],
        [-1.,  4.]])
input_mat:
 Parameter containing:
tensor([[-2.1855,  2.8429, -0.3762],
        [-2.0084,  1.5430, -3.2535]], requires_grad=True)


In [3]:
# Recompute the target Gram matrix for side-by-side comparison with the fit below.
gram_matrix(mat=style_mat)

tensor([[13., 10.],
        [10., 17.]])

In [4]:
# Verify the fit instead of relying on eyeballing the numbers: the optimized
# matrix's Gram must match the target. atol is loose enough for Adam's residual
# error; tighten it if you train longer.
G_fitted = gram_matrix(input_mat)
torch.testing.assert_close(G_fitted.detach(), G_style, rtol=0.0, atol=1e-3)
G_fitted

tensor([[13.0000, 10.0000],
        [10.0000, 17.0000]], grad_fn=<MmBackward0>)