# Introduction

This Notebooks is a join notebook from both the prepare_data and pytorch-bst in order to be run in google colab.

# Prepare data section

In [None]:
!pip install pytorch_lightning



In [None]:
import pandas as pd
import torch
import pytorch_lightning as pl
from tqdm import tqdm
import torchmetrics
import math
from urllib.request import urlretrieve
from zipfile import ZipFile
import os
import torch.nn as nn
import numpy as np
from math import sqrt

In [None]:
!nvidia-smi

Thu May  5 05:57:34 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P8    11W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
torch.__version__

'1.11.0+cu113'

## Settings

In [None]:
WINDOW_SIZE = 20

## Data

In [None]:
raw_df = pd.read_csv('Office_Products.csv', usecols=['rating', 'reviewerID', 'product_id', 'date'])

In [None]:
# Prototyping 목적으로 50만 개의 행만 이용
raw_df = raw_df.iloc[:500000, :]
raw_df.rename(columns={'reviewerID': 'user_id'}, inplace=True)
raw_df.loc[ :, 'rating'] = raw_df.loc[:, 'rating'].apply(lambda x: float(x))
raw_df.head()

Unnamed: 0,rating,user_id,product_id,date
0,3.0,A2WJLOXXIB7NF3,140503528,1162512000.0
1,5.0,A1RKICUK0GG6VF,140503528,1147133000.0
2,5.0,A1QA5E50M398VW,140503528,1142035000.0
3,5.0,A3N0HBW8IP8CZQ,140503528,980294400.0
4,5.0,A1K1JW1C5CUSUZ,140503528,964915200.0


## Make Sequence

In [None]:
df_group = raw_df.sort_values(by=['date']).groupby('user_id')

df = pd.DataFrame(
    data={
        'user_id': list(df_group.groups.keys()),
        'product_id': list(df_group.product_id.apply(list)),
        'rating': list(df_group.rating.apply(list)),
        'date': list(df_group.date.apply(list)),
    }
)

In [None]:
df.tail()

Unnamed: 0,user_id,product_id,rating,date
410484,AZZXNFW30OVPU,[1601064993],[4.0],[1506384000.0]
410485,AZZY5J9F8H7D4,[B0002ABA8E],[5.0],[1519084800.0]
410486,AZZYGB3DSML0J,"[B00006IFMG, B00006IE2K, B00006IE2J, B00006RSO4]","[5.0, 1.0, 1.0, 3.0]","[1434326400.0, 1458345600.0, 1458345600.0, 146..."
410487,AZZYW4YOE1B6E,[B00006IFIK],[4.0],[1362614400.0]
410488,AZZZD0GTOGRYT,[B00000JZKB],[5.0],[1496188800.0]


In [None]:
sequence_length = 3
step_size = 1


def create_sequences(values, window_size, step_size):
    sequences = []
    start_index = 0
    while True:
        end_index = start_index + window_size
        seq = values[start_index:end_index]
        if len(seq) < window_size:
            seq = values[-window_size:]
            if len(seq) == window_size:
                sequences.append(seq)
            break
        sequences.append(seq)
        start_index += step_size
    return sequences


df.product_id = df.product_id.apply(
    lambda ids: create_sequences(ids, sequence_length, step_size)
)

df.rating = df.rating.apply(
    lambda ids: create_sequences(ids, sequence_length, step_size)
)

del df['date']

In [None]:
df_transformed = df.explode(['product_id', 'rating'], ignore_index=True)
df_transformed.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 441528 entries, 0 to 441527
Data columns (total 3 columns):
 #   Column      Non-Null Count   Dtype 
---  ------      --------------   ----- 
 0   user_id     441528 non-null  object
 1   product_id  46732 non-null   object
 2   rating      46732 non-null   object
dtypes: object(3)
memory usage: 10.1+ MB


In [None]:
df_transformed.dropna(axis=0, how='any', inplace=True)
df_transformed.isnull().any()

user_id       False
product_id    False
rating        False
dtype: bool

In [None]:
df_transformed.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 46732 entries, 16 to 441525
Data columns (total 3 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   user_id     46732 non-null  object
 1   product_id  46732 non-null  object
 2   rating      46732 non-null  object
dtypes: object(3)
memory usage: 1.4+ MB


In [None]:
df_transformed.product_id = df_transformed.product_id.apply(lambda x: ",".join(x))
df_transformed.rating = df_transformed.rating.apply(lambda x: ",".join([str(v) for v in x]))
df_transformed.rename(columns={'product_id': 'seq_product_ids', 'rating': 'seq_ratings'}, inplace=True)

In [None]:
random_selection = np.random.rand(len(df_transformed.index)) <= 0.85
train_data = df_transformed[random_selection]
test_data = df_transformed[~random_selection]

train_data.to_csv('train_data.csv', index=False, sep='|')
test_data.to_csv('test_data.csv', index=False, sep='|')

In [None]:
train_data

Unnamed: 0,user_id,seq_product_ids,seq_ratings
16,A0220159ZRNBTRKLG08H,"8862930003,B00006IE7J,B00005249G","5.0,5.0,5.0"
17,A0220159ZRNBTRKLG08H,"B00006IE7J,B00005249G,B00006IEJC","5.0,5.0,5.0"
18,A0220159ZRNBTRKLG08H,"B00006IE7J,B00005249G,B00006IEJC","5.0,5.0,5.0"
31,A03492194F0T997EZQ04,"B00005249G,B00006JNNE,B00006IE7J","5.0,5.0,5.0"
49,A05012776MTIS8L40R3I,"B0000AQOCO,B00006IBK2,B0002ABCIC","5.0,5.0,5.0"
...,...,...,...
441507,AZZRRYBQG57LF,"B00006IQCV,B00006IFIK,B00006IDQT","5.0,5.0,5.0"
441508,AZZRRYBQG57LF,"B00006IQCV,B00006IFIK,B00006IDQT","5.0,5.0,5.0"
441523,AZZYGB3DSML0J,"B00006IFMG,B00006IE2K,B00006IE2J","5.0,1.0,1.0"
441524,AZZYGB3DSML0J,"B00006IE2K,B00006IE2J,B00006RSO4","1.0,1.0,3.0"


In [None]:
train_data.iloc[0].user_id

'A0220159ZRNBTRKLG08H'

# BST Implementation and training

In [None]:
import pandas as pd
import torch
import pytorch_lightning as pl
from tqdm import tqdm
import torchmetrics
import math
from urllib.request import urlretrieve
from zipfile import ZipFile
import os
import torch.nn as nn
import numpy as np
from torchtext.vocab import vocab
from collections import Counter

## Pytorch dataset

In [None]:
data_keys = {
    'user_id': list(df_transformed.user_id.unique()),
    'product_id': list(raw_df.product_id.unique()), #이렇게 해야 개수가 훨 줄어듬
}

In [None]:
len(data_keys['product_id'])

4918

In [None]:
import pandas as pd
import torch
import torch.utils.data as data
from torchvision import transforms
import ast
from torch.nn.utils.rnn import pad_sequence

class AmazonDataset(data.Dataset):
    """Movie dataset."""

    def __init__(
        self, ratings_file,test=False
    ):
        """
        Args:
            csv_file (string): Path to the csv file with user,past,future.
        """
        self.ratings_frame = pd.read_csv(
            ratings_file,
            delimiter="|",
            # iterator=True,
        )
        self.test = test
        self.product_lookup = vocab(Counter(data_keys['product_id']))
        self.user_lookup = vocab(Counter(data_keys['user_id']))

    def __len__(self):
        return len(self.ratings_frame)
    
    def encoding(self, value, table):
        return table[value]

    def __getitem__(self, idx):
        data = self.ratings_frame.iloc[idx]
        user_id = self.encoding(data.user_id, self.user_lookup)
        product_history = list(map(lambda x: int(self.encoding(x, self.product_lookup)), data.seq_product_ids.split(',')))
        product_history_ratings = list(map(float, data.seq_ratings.split(',')))
        target_product_id = product_history[-1:][0]
        target_product_rating = product_history_ratings[-1:][0]
        
        product_history = torch.LongTensor(product_history[:-1])
        product_history_ratings = torch.FloatTensor(product_history_ratings[:-1])

        #target_id는 encoding no

        
        # sex = data.sex
        # age_group = data.age_group
        # occupation = data.occupation
        
        return user_id, product_history, target_product_id,  product_history_ratings, target_product_rating#, sex, age_group, occupation

In [None]:
class PositionalEmbedding(nn.Module):
    """
    Computes positional embedding following "Attention is all you need"
    """

    def __init__(self, max_len, d_model):
        super().__init__()

        # Compute the positional encodings once in log space.
        self.pe = nn.Embedding(max_len, d_model)

    def forward(self, x):
        batch_size = x.size(0)
        return self.pe.weight.unsqueeze(0).repeat(batch_size, 1, 1)


class BST(pl.LightningModule):
    def __init__(
        self, args=None
    ):
        super().__init__()
        super(BST, self).__init__()
        
        self.save_hyperparameters()
        self.args = args
        #-------------------
        # Embedding layers
        ##Users 
        self.embeddings_user_id = nn.Embedding(
            len(data_keys['user_id']), int(math.sqrt(len(data_keys['user_id'])))+1
        )


        self.embeddings_position  = nn.Embedding(
           sequence_length, int(math.sqrt(len(data_keys['product_id'])))
        )

        self.embeddings_product_id = nn.Embedding(
            len(data_keys['product_id']), int(math.sqrt(len(data_keys['product_id'])))
        )
        
        
        self.positional_embedding = PositionalEmbedding(sequence_length, 63)
        
        # Network
        self.transfomerlayer = nn.TransformerEncoderLayer(63, 3, dropout=0.2)
        self.linear = nn.Sequential(
            nn.Linear(
                315,
                512,
            ),
            # nn.LeakyReLU(),
            # nn.Linear(1024, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1),
        )
        self.criterion = torch.nn.MSELoss()
        self.mae = torchmetrics.MeanAbsoluteError()
        self.mse = torchmetrics.MeanSquaredError()
        


    def encode_input(self,inputs):
        user_id, product_history, target_product_id,  product_history_ratings, target_product_rating = inputs
               
        #MOVIES
        product_history = self.embeddings_product_id(product_history)
        target_product = self.embeddings_product_id(target_product_id)
        
        positions = torch.arange(0,sequence_length-1,1,dtype=int,device=self.device)
        positions = self.embeddings_position(positions)

        
        encoded_sequence_products_with_poistion_and_rating = (product_history + positions) *  product_history_ratings.reshape(-1, sequence_length-1,1)
        
        target_product = torch.unsqueeze(target_product, 1)
        transfomer_features = torch.cat((encoded_sequence_products_with_poistion_and_rating, target_product),dim=1)
        
        #USERS
        user_features = self.embeddings_user_id(user_id)
        
        return transfomer_features, user_features, target_product_rating.float()
    
    def forward(self, batch):
        transfomer_features, user_features, target_product_rating = self.encode_input(batch)
        transfomer_features = self.positional_embedding(transfomer_features)
        transformer_output = self.transfomerlayer(transfomer_features)
        transformer_output = torch.flatten(transformer_output,start_dim=1)
        
        #Concat with other features
        features = torch.cat((transformer_output,user_features),dim=1)
        output = self.linear(features)
        return output, target_product_rating
        
    def training_step(self, batch, batch_idx):
        out, target_product_rating = self(batch)
        out = out.flatten()
        loss = self.criterion(out, target_product_rating)
        
        mae = self.mae(out, target_product_rating)
        mse = self.mse(out, target_product_rating)
        rmse =torch.sqrt(mse)
        self.log(
            "train/mae", mae, on_step=True, on_epoch=False, prog_bar=False
        )
        
        self.log(
            "train/rmse", rmse, on_step=True, on_epoch=False, prog_bar=False
        )
        
        self.log("train/step_loss", loss, on_step=True, on_epoch=False, prog_bar=False)
        return loss
    
    def validation_step(self, batch, batch_idx):
        out, target_product_rating = self(batch)
        out = out.flatten()
        loss = self.criterion(out, target_product_rating)
        
        mae = self.mae(out, target_product_rating)
        mse = self.mse(out, target_product_rating)
        rmse =torch.sqrt(mse)
        
        return {"val_loss": loss, "mae": mae.detach(), "rmse":rmse.detach()}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_mae = torch.stack([x["mae"] for x in outputs]).mean()
        avg_rmse = torch.stack([x["rmse"] for x in outputs]).mean()
        
        print("epoch end: 3losses are {0}, {1}, {2}".format(avg_loss, avg_mae, avg_rmse))
        self.log("val/loss", avg_loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log("val/mae", avg_mae, on_step=False, on_epoch=True, prog_bar=False)
        self.log("val/rmse", avg_rmse, on_step=False, on_epoch=True, prog_bar=False)

    # def test_step(self, batch, batch_idx):
    #     out, target_product_rating = self(batch)
    #     out = out.flatten()
    #     loss = self.criterion(out, target_product_rating)
        
    #     mae = self.mae(out, target_product_rating)
    #     mse = self.mse(out, target_product_rating)
    #     rmse =torch.sqrt(mse)
        
    #     return {"users": loss, "top14": mae.detach(), "rmse":rmse.detach()}

    def test_epoch_end(self, outputs):
        users = torch.cat([x["users"] for x in outputs])
        y_hat = torch.cat([x["top14"] for x in outputs])
        users = users.tolist()
        y_hat = y_hat.tolist()
        
        data = {"users": users, "top14": y_hat}
        df = pd.DataFrame.from_dict(data)
        print(len(df))
        df.to_csv("lightning_logs/predict.csv", index=False)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=0.0005)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--learning_rate", type=float, default=0.01)
        return parser

    ####################
    # DATA RELATED HOOKS
    ####################

    def setup(self, stage=None):
        print("Loading datasets")
        self.train_dataset = AmazonDataset("./train_data.csv")
        self.val_dataset = AmazonDataset("./test_data.csv")
        self.test_dataset = AmazonDataset("./test_data.csv")
        print("Done")

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=128,
            shuffle=True,
            num_workers=os.cpu_count(),
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=128,
            shuffle=False,
            num_workers=os.cpu_count(),
        )

    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=128,
            shuffle=False,
            num_workers=os.cpu_count(),
        )
        
model = BST()
trainer = pl.Trainer(gpus=1,max_epochs=50)
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /content/lightning_logs


Loading datasets
Done


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                  | Type                    | Params
------------------------------------------------------------------
0 | embeddings_user_id    | Embedding               | 2.0 M 
1 | embeddings_position   | Embedding               | 210   
2 | embeddings_product_id | Embedding               | 344 K 
3 | positional_embedding  | PositionalEmbedding     | 189   
4 | transfomerlayer       | TransformerEncoderLayer | 276 K 
5 | linear                | Sequential              | 293 K 
6 | criterion             | MSELoss                 | 0     
7 | mae                   | MeanAbsoluteError       | 0     
8 | mse                   | MeanSquaredError        | 0     
------------------------------------------------------------------
2.9 M     Trainable params
0         Non-trainable params
2.9 M     Total params
11.568    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

epoch end: 3losses are 21.7069149017334, 4.57490873336792, 4.659060478210449


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.8407968282699585, 0.48167452216148376, 0.9093762040138245


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.7506809830665588, 0.5410512089729309, 0.8592971563339233


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.7235715389251709, 0.5328536033630371, 0.8440684080123901


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.6704197525978088, 0.5450663566589355, 0.8127197027206421


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.6312693953514099, 0.5295265316963196, 0.7890890836715698


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.597585916519165, 0.5349047183990479, 0.7682714462280273


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.5575977563858032, 0.47934117913246155, 0.7410440444946289


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.5216544270515442, 0.47559404373168945, 0.7163107991218567


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.5073041319847107, 0.471591591835022, 0.7055206298828125


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.4722200632095337, 0.40127289295196533, 0.6771570444107056


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.4845844507217407, 0.4734114110469818, 0.6880738735198975


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.43286505341529846, 0.3878542482852936, 0.6487045884132385


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.42700013518333435, 0.387859582901001, 0.643481433391571


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.4241846203804016, 0.3562055230140686, 0.6400998830795288


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.41325363516807556, 0.3748793601989746, 0.6333072781562805


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.42421936988830566, 0.349624902009964, 0.6400275230407715


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.40669894218444824, 0.35755014419555664, 0.6268888711929321


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.43306758999824524, 0.3802315294742584, 0.6475277543067932


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.4121438264846802, 0.3514273762702942, 0.630022406578064


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.3998691737651825, 0.34293726086616516, 0.6220071911811829


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.39896875619888306, 0.32624509930610657, 0.620270848274231


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.39045241475105286, 0.32390761375427246, 0.6134487390518188


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.3980351984500885, 0.326850563287735, 0.6193607449531555


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.3995009660720825, 0.3340894877910614, 0.6208756566047668


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.3933144509792328, 0.32052209973335266, 0.6158266067504883


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.3952414095401764, 0.3216024935245514, 0.6174268126487732


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.3869721591472626, 0.3190995752811432, 0.6106406450271606


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.3847852051258087, 0.31960445642471313, 0.6092322468757629


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.3973035514354706, 0.33032697439193726, 0.619308590888977


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.394485741853714, 0.32064247131347656, 0.6169149279594421


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.39460504055023193, 0.3263705372810364, 0.6169247627258301


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.38897550106048584, 0.32348236441612244, 0.6126658916473389


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.38131481409072876, 0.3080388009548187, 0.6062115430831909


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.3876313865184784, 0.33835673332214355, 0.6125210523605347


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.3890959322452545, 0.3152131140232086, 0.612582802772522


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.38059449195861816, 0.31464850902557373, 0.6068704724311829


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.39150962233543396, 0.34182605147361755, 0.6159229278564453


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.39248451590538025, 0.3246217370033264, 0.6156530976295471


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.38664528727531433, 0.3130730986595154, 0.6103042364120483


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.38967081904411316, 0.32001250982284546, 0.6125599145889282


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.3792090117931366, 0.3048465847969055, 0.6042056679725647


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.3828907012939453, 0.3073230981826782, 0.6075954437255859


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.38492825627326965, 0.31055399775505066, 0.60908043384552


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.3855357766151428, 0.311642587184906, 0.6100123524665833


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.3879588842391968, 0.314502477645874, 0.6114062070846558


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.38427841663360596, 0.30367034673690796, 0.6082412600517273


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.3847707509994507, 0.3068960905075073, 0.6087231636047363


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.38759660720825195, 0.3319183588027954, 0.611766517162323


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.37988993525505066, 0.3051966428756714, 0.604941189289093


Validation: 0it [00:00, ?it/s]

epoch end: 3losses are 0.37654149532318115, 0.30337631702423096, 0.6024641394615173


In [None]:
'''
  different approach - use model for classification
'''
class PositionalEmbedding(nn.Module):
    """
    Computes positional embedding following "Attention is all you need"
    """

    def __init__(self, max_len, d_model):
        super().__init__()

        # Compute the positional encodings once in log space.
        self.pe = nn.Embedding(max_len, d_model)

    def forward(self, x):
        batch_size = x.size(0)
        return self.pe.weight.unsqueeze(0).repeat(batch_size, 1, 1)


class BST_amazon(pl.LightningModule):
    def __init__(
        self, args=None
    ):
        super().__init__()
        super(BST_amazon, self).__init__()
        
        self.save_hyperparameters()
        self.args = args
        #-------------------
        # Embedding layers
        ##Users 
        self.embeddings_user_id = nn.Embedding(
            len(data_keys['user_id']), int(math.sqrt(len(data_keys['user_id'])))+1
        )


        self.embeddings_position  = nn.Embedding(
           sequence_length, int(math.sqrt(len(data_keys['product_id'])))
        )

        self.embeddings_product_id = nn.Embedding(
            len(data_keys['product_id']), int(math.sqrt(len(data_keys['product_id'])))
        )
        
        
        self.positional_embedding = PositionalEmbedding(sequence_length-1, 63)
        
        # Network
        self.transfomerlayer = nn.TransformerEncoderLayer(63, 3, dropout=0.2)
        
        self.linear = nn.Sequential(
            nn.Linear(
                252,
                1024,
            ),
            # nn.LeakyReLU(),
            # nn.Linear(1024, 512),
            nn.LeakyReLU(),
            nn.Linear(1024, 2048),
            nn.LeakyReLU(),
            nn.Linear(2048, len(data_keys['product_id'])),
        )
        self.criterion = torch.nn.CrossEntropyLoss()
        # self.mae = torchmetrics.MeanAbsoluteError()
        # self.mse = torchmetrics.MeanSquaredError()
        


    def encode_input(self,inputs):
        user_id, product_history, target_product_id,  product_history_ratings, target_product_rating = inputs
               
        #MOVIES
        product_history = self.embeddings_product_id(product_history)
        # target_product = self.embeddings_product_id(target_product_id)
        
        positions = torch.arange(0,sequence_length-1,1,dtype=int,device=self.device)
        positions = self.embeddings_position(positions)

        
        encoded_sequence_products_with_poistion_and_rating = (product_history + positions) *  product_history_ratings.reshape(-1, sequence_length-1,1)#Yet to multiply by rating
        
        # target_product = torch.unsqueeze(target_product, 1)
        # transfomer_features = torch.cat((encoded_sequence_products_with_poistion_and_rating, target_product),dim=1)
        transfomer_features = encoded_sequence_products_with_poistion_and_rating
        #USERS
        user_features = self.embeddings_user_id(user_id)
        
        return transfomer_features, user_features, target_product_id
    
    def forward(self, batch):
        transfomer_features, user_features, target_product_id = self.encode_input(batch)
        transfomer_features = self.positional_embedding(transfomer_features)
        transformer_output = self.transfomerlayer(transfomer_features)
        transformer_output = torch.flatten(transformer_output,start_dim=1)
        
        #Concat with other features
        features = torch.cat((transformer_output,user_features),dim=1)
        output = self.linear(features)
        return output, target_product_id
        
    def training_step(self, batch, batch_idx):
        out, target_product_id = self(batch)
        loss = self.criterion(out, target_product_id)
        
        self.log("train/step_loss", loss, on_step=True, on_epoch=False, prog_bar=False)
        return loss
    
    def validation_step(self, batch, batch_idx):
        out, target_product_id = self(batch)
        loss = self.criterion(out, target_product_id)
        
        
        return {"val_loss": loss}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        
        
        print("epoch end: loss {0}".format(avg_loss))
        self.log("val/loss", avg_loss, on_step=False, on_epoch=True, prog_bar=False)
        

    # def test_step(self, batch, batch_idx):
    #     out, target_product_rating = self(batch)
    #     out = out.flatten()
    #     loss = self.criterion(out, target_product_rating)
        
    #     mae = self.mae(out, target_product_rating)
    #     mse = self.mse(out, target_product_rating)
    #     rmse =torch.sqrt(mse)
        
    #     return {"users": loss, "top14": mae.detach(), "rmse":rmse.detach()}

    def test_epoch_end(self, outputs):
        users = torch.cat([x["users"] for x in outputs])
        y_hat = torch.cat([x["top14"] for x in outputs])
        users = users.tolist()
        y_hat = y_hat.tolist()
        
        data = {"users": users, "top14": y_hat}
        df = pd.DataFrame.from_dict(data)
        print(len(df))
        df.to_csv("lightning_logs/predict.csv", index=False)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=0.0005)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--learning_rate", type=float, default=0.01)
        return parser

    ####################
    # DATA RELATED HOOKS
    ####################

    def setup(self, stage=None):
        print("Loading datasets")
        self.train_dataset = AmazonDataset("./train_data.csv")
        self.val_dataset = AmazonDataset("./test_data.csv")
        self.test_dataset = AmazonDataset("./test_data.csv")
        print("Done")

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=128,
            shuffle=True,
            num_workers=os.cpu_count(),
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=128,
            shuffle=False,
            num_workers=os.cpu_count(),
        )

    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=128,
            shuffle=False,
            num_workers=os.cpu_count(),
        )
        
model = BST_amazon()
trainer = pl.Trainer(gpus=1,max_epochs=50)
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Loading datasets


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                  | Type                    | Params
------------------------------------------------------------------
0 | embeddings_user_id    | Embedding               | 2.0 M 
1 | embeddings_position   | Embedding               | 210   
2 | embeddings_product_id | Embedding               | 344 K 
3 | positional_embedding  | PositionalEmbedding     | 126   
4 | transfomerlayer       | TransformerEncoderLayer | 276 K 
5 | linear                | Sequential              | 12.4 M
6 | criterion             | CrossEntropyLoss        | 0     
------------------------------------------------------------------
15.0 M    Trainable params
0         Non-trainable params
15.0 M    Total params
60.135    Total estimated model params size (MB)


Done


Sanity Checking: 0it [00:00, ?it/s]

epoch end: loss 8.500458717346191


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

epoch end: loss 6.558570861816406


Validation: 0it [00:00, ?it/s]

epoch end: loss 6.285812854766846


Validation: 0it [00:00, ?it/s]

epoch end: loss 5.886375427246094


Validation: 0it [00:00, ?it/s]

epoch end: loss 5.54004430770874


Validation: 0it [00:00, ?it/s]

epoch end: loss 5.353392124176025


Validation: 0it [00:00, ?it/s]

epoch end: loss 5.07771110534668


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.975618839263916


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.917363166809082


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.798571586608887


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.816004276275635


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.776594638824463


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.737738609313965


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.674821376800537


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.6823601722717285


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.68457555770874


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.6806960105896


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.685539722442627


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.661981105804443


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.658028602600098


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.709644317626953


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.67482328414917


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.66140079498291


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.6451520919799805


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.607969760894775


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.639594078063965


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.6683831214904785


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.646440505981445


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.627695560455322


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.679287433624268


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.647072792053223


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.647062301635742


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.6515092849731445


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.663784980773926


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.737728118896484


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.676122665405273


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.6888933181762695


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.666663646697998


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.662160396575928


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.71327543258667


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.705243110656738


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.691624641418457


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.703766822814941


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.760776519775391


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.6937255859375


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.7198100090026855


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.762923717498779


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.795872688293457


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.711773872375488


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.778147220611572


Validation: 0it [00:00, ?it/s]

epoch end: loss 4.689154624938965


In [None]:
def 

SyntaxError: ignored