In [1]:
!pip install -Uqq fastbook
import fastbook
fastbook.setup_book()

In [2]:
from fastbook import *
from fastai.collab import *
from fastai.tabular.all import *

### Data

In [3]:
path = untar_data(URLs.ML_100k)

In [4]:
ratings = pd.read_csv(path/'u.data', delimiter='\t', header=None, names=['user', 'movie', 'rating', 'timestamp'])
movies = pd.read_csv(path/'u.item', delimiter='|', encoding='latin-1', usecols=(0,1), names=('movie', 'title'), header=None)
ratings = ratings.merge(movies)
ratings.head()

Unnamed: 0,user,movie,rating,timestamp,title
0,196,242,3,881250949,Kolya (1996)
1,63,242,3,875747190,Kolya (1996)
2,226,242,5,883888671,Kolya (1996)
3,154,242,3,879138235,Kolya (1996)
4,306,242,5,876503793,Kolya (1996)


### Creating DataLoaders

In [5]:
dls = CollabDataLoaders.from_df(ratings, item_name='title', bs=64)

In [6]:
embs = get_emb_sz(dls)
embs

[(944, 74), (1665, 102)]

### Creating a model for Cross Entropy Loss

In [7]:
class CollabCEL(Module):
    def __init__(self, user_embsize, item_embsize, n_activations=100, y_range=(0,5.5)):
        self.user_factors = Embedding(*user_embsize)
        self.item_factors = Embedding(*item_embsize)
        self.layers = nn.Sequential(
            nn.Linear(user_embsize[1]+item_embsize[1], n_activations),
            nn.ReLU(),
            nn.Linear(n_activations, 6)
        )
        self.y_range = y_range
    
    def forward(self, x):
        embs = self.user_factors(x[:,0]), self.item_factors(x[:,1])
        concat_embs = torch.cat(embs, dim=1)
        net_result = self.layers(concat_embs)
        return sigmoid_range(net_result, *self.y_range)

### Learning

In [15]:
model = CollabCEL(*embs)
learn = Learner(dls, model, loss_func=CrossEntropyLossFlat())

In [16]:
learn.fit_one_cycle(10, 1e-2, wd=0.05)

epoch,train_loss,valid_loss,time
0,1.30096,1.303658,00:15
1,1.272742,1.301013,00:15
2,1.264107,1.298234,00:15
3,1.277428,1.281321,00:15
4,1.24829,1.262478,00:15
5,1.204677,1.255076,00:15
6,1.177123,1.246983,00:15
7,1.124074,1.24857,00:15
8,1.085737,1.260726,00:14
9,1.03265,1.271948,00:15
