In [1]:
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 

import os

import numpy as np
import torch
# Report the runtime and pick the compute device for the rest of the notebook.
print("PyTorch version:",torch.__version__)
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        # Query the capability of device i itself (not always device 0) so that
        # every GPU in a mixed-GPU machine is reported correctly.
        cc_major, cc_minor = torch.cuda.get_device_capability(i)
        print(f"CUDA GPU {i+1}: {torch.cuda.get_device_name(i)} [Compute Capability: {cc_major}.{cc_minor}]")
    device = torch.device('cuda')
    # NOTE(review): presumably intended as DataLoader keyword arguments; no code
    # visible in this notebook consumes `kwargs` - confirm before relying on it.
    kwargs = {'num_workers': 8, 'pin_memory': True}
    torch.backends.cudnn.benchmark = True
else:
    device = torch.device('cpu')
    # Bind `kwargs` on the CPU path too so the name exists on every host.
    kwargs = {}
    print("CUDA GPU is not available. :(")
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
print ("PyTorch Lightning version:",pl.__version__)
    
import scipy.sparse as sp
from argparse import Namespace

from utilities.custom_lightning import CSVProfiler

import logging
# Keep the root logger under its own name: a later cell binds `logger` to a
# TensorBoardLogger, and reusing one name for two different kinds of object
# makes out-of-order cell execution confusing.
root_logger = logging.getLogger()
root_logger.setLevel(logging.DEBUG)
# Module-level logging.debug() is used (rather than root_logger.debug()) so the
# call goes through the logging module's own convenience entry point.
logging.debug("Logging enabled at DEBUG level.")

from constants import (SEED, DATA_DIR, LOG_DIR, TRAIN_DATA_PATH, VAL_DATA_PATH, TEST_DATA_PATH)
# Apply the shared project seed: first through torch.manual_seed, then through
# torch.cuda.manual_seed (same order as before).
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(SEED)
del _seed_fn  # do not leave the loop helper in the notebook namespace

PyTorch version: 1.4.0
CUDA GPU 1: GeForce RTX 2080 Ti [Compute Capability: 7.5]


DEBUG:root:Logging enabled at DEBUG level.


PyTorch Lightning version: 0.7.3


In [2]:
# Name of this experiment; used for the checkpoint, profiler and log locations.
NAME = r'AdamUXML'
# Build both paths with os.path.join so the separator matches the host OS
# (previously one path used '/' and the other a hard-coded Windows '\\').
SAVE_PATH = os.path.join(DATA_DIR, NAME + r'.pt')
PROFILE_PATH = os.path.join(LOG_DIR, NAME, 'profile.csv')

In [3]:
class Interactions(Dataset):
    """
    Map-style dataset over the stored entries of a sparse interaction matrix.

    Every sample is one stored entry, returned as ``((row, col), value)``;
    the dataset length is the number of stored entries (``nnz``).
    """

    def __init__(self, matrix):
        """
        Args:
            matrix: a SciPy sparse matrix in any format (or any other input
                accepted by ``scipy.sparse.coo_matrix``).
        """
        # __getitem__ indexes the parallel row/col/data arrays that only the
        # COO format exposes, so convert anything else up front. A matrix that
        # is already COO is kept as-is (same object, no copy).
        if not sp.isspmatrix_coo(matrix):
            matrix = sp.coo_matrix(matrix)
        self.matrix = matrix
        self.n_users = self.matrix.shape[0]
        self.n_items = self.matrix.shape[1]

    def __getitem__(self, index):
        """Return ``((row, col), value)`` for the ``index``-th stored entry."""
        row = self.matrix.row[index]
        col = self.matrix.col[index]
        val = self.matrix.data[index]
        return (row, col), val

    def __len__(self):
        """Number of stored entries in the matrix."""
        return self.matrix.nnz

# Alias used by the model's dataloaders to construct the dataset.
interaction = Interactions

In [4]:
class TestingCallbacks(pl.Callback):
    """
    Test-loop callback that manages the module-level ``y_hat`` prediction matrix.

    ``on_test_start`` (re)creates ``y_hat`` as an empty DOK matrix; the model's
    ``test_step`` is expected to write into it; ``on_test_end`` logs the number
    of stored entries and saves the matrix in COO form.
    """

    def on_test_start(self, trainer, pl_module):
        """Allocate an empty (total_users x total_items) float32 DOK matrix."""
        global y_hat
        # Take the dimensions from the module under test instead of the
        # notebook-global `hparams`, so the callback does not depend on a
        # variable defined in a later cell.
        model_hparams = pl_module.hparams
        y_hat = sp.dok_matrix((model_hparams.total_users, model_hparams.total_items), dtype=np.float32)

    def on_test_end(self, trainer, pl_module):
        """Log the size of the prediction matrix and save it under DATA_DIR."""
        logging.debug(f"Non-zero values in prediction matrix: {y_hat.nnz:,}")
        # Join with a real path separator; plain string concatenation glued
        # NAME directly onto the directory name.
        sp.save_npz(os.path.join(DATA_DIR, NAME + r'-y_hat.npz'), y_hat.tocoo())



In [5]:
class AdamUXML(pl.LightningModule):
    """
    Matrix-factorisation model with user and item bias terms.

    The prediction for a (user, item) pair is
    ``user_bias + item_bias + sum(dropout(user_factors) * dropout(item_factors))``
    and the model is trained, validated and tested with mean squared error.
    """

    def __init__(self, hparams):
        """
        Args:
            hparams: namespace providing ``total_users``, ``total_items``,
                ``n_factors``, ``sparse``, ``dropout_p``, ``batch_size``,
                ``learning_rate`` and ``betas``.
        """
        super().__init__()
        self.hparams = hparams
        self.user_factors = nn.Embedding(hparams.total_users, hparams.n_factors, sparse=hparams.sparse)
        self.item_factors = nn.Embedding(hparams.total_items, hparams.n_factors, sparse=hparams.sparse)
        self.user_biases = nn.Embedding(hparams.total_users, 1, sparse=hparams.sparse)
        self.item_biases = nn.Embedding(hparams.total_items, 1, sparse=hparams.sparse)
        self.dropout = nn.Dropout(p=self.hparams.dropout_p)

    def forward(self, users, items):
        """
        Predict one value per (user, item) pair.

        Args:
            users: 1-D integer tensor of user indices.
            items: 1-D integer tensor of item indices, same length as ``users``.

        Returns:
            1-D tensor with one prediction per pair.
        """
        user_vecs = self.dropout(self.user_factors(users))
        item_vecs = self.dropout(self.item_factors(items))
        dot = (user_vecs * item_vecs).sum(dim=1, keepdim=True)
        predictions = self.user_biases(users) + self.item_biases(items) + dot
        # Squeeze only the trailing singleton dimension: a bare squeeze() would
        # also drop the batch dimension when a batch holds a single sample.
        return predictions.squeeze(dim=1)

    def MSELoss(self, logits, labels):
        """Mean squared error between predictions and targets."""
        return nn.functional.mse_loss(logits, labels)

    def _shared_step(self, batch):
        """Unpack a ``((row, col), value)`` batch, run the model and compute the loss.

        Returns:
            ``(loss, logits, row, column)`` with the indices cast to int64.
        """
        (row, column), y = batch
        row = row.long()
        column = column.long()
        logits = self.forward(row, column)
        return self.MSELoss(logits, y), logits, row, column

    def training_step(self, train_batch, batch_idx):
        """Compute the training loss for one batch and log it as ``train_loss``."""
        loss, _, _, _ = self._shared_step(train_batch)
        logs = {'train_loss': loss}
        return {'loss': loss, 'log': logs}

    def validation_step(self, val_batch, batch_idx):
        """Compute the validation loss for one batch."""
        loss, _, _, _ = self._shared_step(val_batch)
        return {'val_loss': loss}

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation losses and log the mean as ``val_loss``."""
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

    def test_step(self, test_batch, batch_idx):
        """Compute the test loss for one batch and record its predictions in ``y_hat``.

        ``y_hat`` is a module-level sparse matrix that must exist before the
        test loop starts.
        """
        loss, logits, row, column = self._shared_step(test_batch)

        logits_array = logits.detach().cpu().numpy()
        r = row.cpu().numpy()
        c = column.cpu().numpy()
        for user_idx, item_idx, prediction in zip(r, c, logits_array):
            y_hat[user_idx, item_idx] = prediction

        return {'test_loss': loss}

    def test_epoch_end(self, outputs):
        """Average the per-batch test losses, print the mean and log it as ``MSE``."""
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        tensorboard_logs = {'MSE': avg_loss}
        print(f"Test Mean Squared Error (MSE): {avg_loss}")

        return {'avg_test_loss': avg_loss, 'log': tensorboard_logs}

    def prepare_data(self):
        """Load the train/validation/test sparse matrices from their .npz files."""
        self.train_dataset = sp.load_npz(TRAIN_DATA_PATH)
        self.val_dataset = sp.load_npz(VAL_DATA_PATH)
        self.test_dataset = sp.load_npz(TEST_DATA_PATH)

    def train_dataloader(self):
        """Shuffled loader over the training interactions."""
        return DataLoader(interaction(self.train_dataset), batch_size=self.hparams.batch_size, shuffle=True)

    def val_dataloader(self):
        """Loader over the validation interactions, in fixed order."""
        # Evaluation does not benefit from shuffling; a fixed order keeps the
        # batch composition identical from run to run.
        return DataLoader(interaction(self.val_dataset), batch_size=self.hparams.batch_size, shuffle=False)

    def test_dataloader(self):
        """Loader over the test interactions, in fixed order."""
        return DataLoader(interaction(self.test_dataset), batch_size=self.hparams.batch_size, shuffle=False)

    def configure_optimizers(self):
        """Create the optimizer from ``learning_rate`` and ``betas``.

        Embeddings built with ``sparse=True`` produce sparse gradients, which
        ``torch.optim.Adam`` rejects, so ``SparseAdam`` is used in that case.
        """
        optimizer_cls = torch.optim.SparseAdam if self.hparams.sparse else torch.optim.Adam
        optimizer = optimizer_cls(list(self.parameters()), lr=self.hparams.learning_rate, betas=self.hparams.betas)
        return optimizer

In [12]:
# Hyper-parameters for this run; the model reads them through `self.hparams`.
hparams = Namespace(
    batch_size=1024,
    learning_rate=0.001,
    betas=(0.9, 0.999),
    n_factors=20,
    dropout_p=0.02,
    sparse=False,
    max_epochs=100,
    total_users=177592,
    total_items=44780,
)

profiler = CSVProfiler(output_path=PROFILE_PATH,verbose=True)
logger = TensorBoardLogger(LOG_DIR, name=NAME)
model = AdamUXML(hparams)
trainer = pl.Trainer(max_epochs=hparams.max_epochs,
                     benchmark=True,
                     profiler=profiler,
                     logger=logger,
                     # Only request a GPU when one is present, so the cell also
                     # runs on the CPU-only path handled by the setup cell.
                     gpus=1 if torch.cuda.is_available() else None,
                     fast_dev_run=False,
                     callbacks=[TestingCallbacks()])
trainer.fit(model)

INFO:lightning:GPU available: True, used: True
INFO:lightning:VISIBLE GPUS: 0
INFO:lightning:
  | Name         | Type      | Params
---------------------------------------
0 | user_factors | Embedding | 3 M   
1 | item_factors | Embedding | 895 K 
2 | user_biases  | Embedding | 177 K 
3 | item_biases  | Embedding | 44 K  
4 | dropout      | Dropout   | 0     


HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…



HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …



HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=541.0, style=Pr…


[1mProfiler Report[0m
                   mean_duration  total_time
on_train_start              0.00        0.00
on_epoch_start              0.00        0.00
get_train_batch             0.00      570.08
on_batch_start              0.00        2.33
model_forward               0.00      218.53
model_backward              0.00      581.54
on_after_backward           0.00        0.33
optimizer_step              0.00      199.49
on_batch_end                0.00        2.52
on_epoch_end                0.00        0.00
on_train_end                0.00        0.00

Profiler output saved to: C:\TensorLogs\AdamUXML\profile.csv


1

In [13]:
# Run the test loop with the trainer created in the training cell; the
# TestingCallbacks instance registered there handles the prediction matrix.
trainer.test()

HBox(children=(FloatProgress(value=0.0, description='Testing', layout=Layout(flex='2'), max=540.0, style=Progr…

DEBUG:root:Non-zero values in prediction matrix: 552,255


Test Mean Squared Error (MSE): 0.2457655668258667
--------------------------------------------------------------------------------
TEST RESULTS
{'MSE': 0.2457655668258667, 'avg_test_loss': 0.2457655668258667}
--------------------------------------------------------------------------------



In [8]:
# torch.save(model.state_dict(), SAVE_PATH)

In [9]:
# loaded_model = AdamUXML(hparams)
# loaded_model.load_state_dict(torch.load(SAVE_PATH))
# loaded_model.eval()
# print("Model's state_dict:")
# for param_tensor in loaded_model.state_dict():
#     print(param_tensor, "\t", loaded_model.state_dict()[param_tensor].size())

In [10]:
# loaded_model.state_dict()['user_factors.weight']

In [11]:
# loaded_model.state_dict()['item_factors.weight']