In [27]:
import random
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
import tqdm

In [28]:
if torch.cuda.is_available():
    device = torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')


In [29]:


class MovieLensDataset(Dataset):
    def __init__(self, path):
        data = pd.read_csv(path, sep="\t").values
        self.items = data[:, :2].astype(np.int32) - 1  # -1 because ID begins from 1
        self.targets = self.__preprocess_target(data[:, 2]).astype(np.float32)
        self.field_dims = np.max(self.items, axis=0) + 1
        self.user_field_idx = np.array((0, ), dtype=np.int32)
        self.item_field_idx = np.array((1,), dtype=np.int32)

    def __len__(self):
        return self.items.shape[0]

    def __getitem__(self, index):
        return self.items[index], self.targets[index]

    def __preprocess_target(self, target):
        # return target
        target[target <= 3] = 0
        target[target > 3] = 1
        return target

In [30]:
class FM(nn.Module):
    def __init__(self, field_dims, dim = 10):
        super().__init__()
        self.embedding = torch.nn.Embedding(sum(field_dims), dim)
        self.fc = torch.nn.Embedding(sum(field_dims), 1)
        self.bias = torch.nn.Parameter(torch.zeros((1,)))

    def forward(self, x):
        square_of_sum = torch.sum(self.embedding(x), dim=1) ** 2
        sum_of_square = torch.sum(self.embedding(x) ** 2, dim=1)
        ix = 0.5 * (square_of_sum - sum_of_square)
        ix = torch.sum(ix, dim=1, keepdim=True)
        x = self.bias + torch.sum(self.fc(x), dim=1) + ix
        return torch.sigmoid(x.squeeze(1))

In [31]:
def train(model, optimizer, data_loader, criterion, device, log_interval=100):
    model.train()
    total_loss = 0
    tk0 = tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0)
    for i, (fields, target) in enumerate(tk0):
        fields, target = fields.to(device), target.to(device)
        y = model(fields)
        loss = criterion(y, target)
        model.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        if (i + 1) % log_interval == 0:
            tk0.set_postfix(loss=total_loss / log_interval)
            total_loss = 0

def test(model, data_loader, device):
    model.eval()
    targets, predicts = [], []
    with torch.no_grad():
        for i, (x, y) in enumerate(data_loader):
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            targets.extend(y.tolist())
            predicts.extend(y_hat.tolist())
    return roc_auc_score(targets, predicts)

In [32]:
dataset = MovieLensDataset("../data/ml-100k/u.data")

train_length = int(len(dataset) * 0.8)
valid_length = int(len(dataset) * 0.1)
test_length = len(dataset) - train_length - valid_length

train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(
    dataset, (train_length, valid_length, test_length)
)

train_data_loader = DataLoader(train_dataset, batch_size=32)
valid_data_loader = DataLoader(valid_dataset, batch_size=32)
test_data_loader = DataLoader(test_dataset, batch_size=32)

field_dims = dataset.field_dims

In [34]:
model = FM(field_dims, 16).to(device)
print(model.__class__.__name__)
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01, weight_decay=1e-5)

for epoch_i in range(10):
    train(model, optimizer, train_data_loader, criterion, device)
    auc = test(model, valid_data_loader, device)
test_auc = test(model, test_data_loader, device)
print("test auc:", test_auc)

FM


  0%|          | 0/2500 [00:00<?, ?it/s]

100%|██████████| 2500/2500 [00:30<00:00, 81.77it/s, loss=0.687]
100%|██████████| 2500/2500 [00:16<00:00, 150.08it/s, loss=0.611]
100%|██████████| 2500/2500 [00:17<00:00, 144.25it/s, loss=0.583]
100%|██████████| 2500/2500 [00:17<00:00, 145.39it/s, loss=0.565]
100%|██████████| 2500/2500 [00:16<00:00, 147.06it/s, loss=0.555]
100%|██████████| 2500/2500 [00:17<00:00, 140.11it/s, loss=0.549]
100%|██████████| 2500/2500 [00:17<00:00, 143.65it/s, loss=0.544]
100%|██████████| 2500/2500 [00:18<00:00, 137.86it/s, loss=0.541]
100%|██████████| 2500/2500 [00:17<00:00, 143.21it/s, loss=0.539]
100%|██████████| 2500/2500 [00:16<00:00, 147.49it/s, loss=0.538]


test auc: 0.6722590109376825
