# Factorization Machines


目标函数：

$$\begin{align}\hat y(x) &= w_0 + \sum_{i=1}^{d} w_i x_i + \sum_{i=1}^{d}\sum_{j=i+1}^{d} \langle v_i,v_j \rangle x_i x_j \\
&= w_0 + \sum_{i=1}^{d} w_i x_i + \frac{1}{2} \sum_{l=1}^{k}\left(\left(\sum_{i=1}^{d} v_{il} x_i\right)^{2}-\sum_{i=1}^{d} v_{il}^{2} x_i^{2}\right)
\end{align}$$

In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.utils.data as Data
from sklearn.model_selection import train_test_split

In [99]:
class FM(nn.Module):
    """Second-order Factorization Machine over index-encoded features, with a sigmoid output.

    For each sample it computes

        s = sum_i w_i x_i + 1/2 * sum_l [ (sum_i v_il x_i)^2 - sum_i v_il^2 x_i^2 ]
        y = sigmoid(a * s + b)

    The bracketed term is the linear-time rewrite of sum_{i<j} <v_i, v_j> x_i x_j.
    w_i is read from ``self.X`` and v_i from ``self.V``. (a, b) are the weight and bias of
    ``self.linear``; b plays the role of the global bias w_0.

    Args:
        input_size: number of distinct feature indices (rows of each embedding table).
        num_factors: latent dimension k of every factor vector v_i.
    """

    def __init__(self, input_size, num_factors):
        super().__init__()
        # Latent factor vectors v_i: one row of length num_factors per feature index.
        self.V = nn.Embedding(num_embeddings=input_size, embedding_dim=num_factors)
        # First-order weights w_i, stored as a 1-wide embedding so they are looked up by index.
        self.X = nn.Embedding(num_embeddings=input_size, embedding_dim=1)
        # Learned scale a and bias b applied to the FM score before the sigmoid.
        self.linear = nn.Linear(in_features=1, out_features=1, bias=True)

    def forward(self, x, values=None):
        """Score a batch of samples.

        Args:
            x: integer tensor of feature indices, shape (batch, fields).
            values: optional tensor of feature values x_i with the same shape as ``x``.
                When omitted, every looked-up feature has value 1 (binary / one-hot
                features). This is the behaviour of the original index-only model.

        Returns:
            Tensor of shape (batch, 1) with sigmoid outputs in (0, 1).
        """
        v = self.V(x)  # (batch, fields, num_factors)
        w = self.X(x)  # (batch, fields, 1)
        if values is not None:
            # Broadcast each feature value over its factor / weight dimension.
            scale = values.to(v.dtype).unsqueeze(-1)
            v = v * scale
            w = w * scale
        # 1/2 * [(sum_i v_i)^2 - sum_i v_i^2], then summed over factors -> (batch, 1).
        pairwise = 0.5 * (v.sum(dim=1).pow(2) - v.pow(2).sum(dim=1)).sum(dim=1, keepdim=True)
        first_order = w.sum(dim=1)  # (batch, 1)
        return torch.sigmoid(self.linear(pairwise + first_order))

In [100]:
# The embedding tables are randomly initialised, so seed torch before any model is built.
SEED = 0
torch.manual_seed(SEED)

# Number of distinct feature indices (rows of each embedding table) and latent factor dimension k.
input_size = 10000
num_factors = 10

In [101]:
# Build the factorization machine from the configured vocabulary size and factor dimension.
model = FM(input_size=input_size, num_factors=num_factors)

In [102]:
# Show the module structure via the cell's last-expression display rather than print().
model

FM(
  (V): Embedding(10000, 10)
  (X): Embedding(10000, 1)
  (linear): Linear(in_features=1, out_features=1, bias=True)
)


In [103]:
# Two samples of two feature indices each; the model returns one sigmoid score per sample.
sample_indices = torch.tensor([[9, 9], [10, 10]])
model(sample_indices)

tensor([[0.9991],
        [0.9973]], grad_fn=<SigmoidBackward>)

In [17]:
# Toy batch of index pairs used to check the reductions in FM.forward.
a = torch.tensor([[9] * 2, [10] * 2])

In [18]:
# Look up the signature of Tensor.sum (dim / keepdim), which the FM forward pass relies on.
help(a.sum)

Help on built-in function sum:

sum(...) method of torch.Tensor instance
    sum(dim=None, keepdim=False, dtype=None) -> Tensor
    
    See :func:`torch.sum`



In [94]:
# Square of each row's sum: the "(sum)^2" half of the pairwise-interaction rewrite.
a.sum(dim=1).pow(2)

tensor([324, 400])

In [88]:
# Sum of squares per row, keeping the reduced axis: the "sum of squares" half of the rewrite.
a.pow(2).sum(dim=-1, keepdim=True)

tensor([[162],
        [200]])

In [72]:
# Stand-alone factor table, shaped like FM.V, for inspecting lookup shapes.
en = nn.Embedding(input_size, num_factors)

In [73]:
# Look up a factor vector for every index in the toy batch.
b = en(input=a)

In [96]:
# Shape of the factor lookup.
b.size()

torch.Size([2, 2, 10])

In [97]:
# Shape after summing factor vectors over the index axis (dim 1).
b.sum(1).size()

torch.Size([2, 10])

In [85]:
# Stand-alone 1-wide weight table, shaped like FM.X.
X = nn.Embedding(input_size, 1)

In [86]:
# Look up the first-order weight for every index in the toy batch.
c = X(input=a)

In [87]:
# Shape of the first-order weight lookup.
c.size()

torch.Size([2, 2, 1])