In [None]:
import torch
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

In [None]:
# Run on the first visible CUDA GPU when PyTorch can see one, otherwise on CPU.
# To force CPU execution, replace this with: device = torch.device("cpu")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
from pathlib import Path

import pandas as pd

# Folder holding the parquet inputs used throughout this notebook.
data_root = Path("../kkdata3")

# List what is actually in the folder (handy when a path is wrong).
for x in data_root.glob("*"):
    print(x)

# Read the four tables in a fixed order: train source, train target,
# test source and song metadata.
train_source, train_target, test_source, meta_song = (
    pd.read_parquet(data_root / f"{stem}.parquet")
    for stem in (
        "label_train_source",
        "label_train_target",
        "label_test_source",
        "meta_song",
    )
)

In [None]:
# Order each frame by session and then by listening_order, so every session's
# songs form one contiguous run in play order (getTrainData relies on this).
train_source, train_target, test_source = (
    frame.sort_values(["session_id", "listening_order"])
    for frame in (train_source, train_target, test_source)
)

In [None]:
# Map song_id to a compact integer song_index to save memory and speed things up.
# Indices start at 1, not 0: id 0 is the padding token in this notebook
# (nn.Embedding(..., padding_idx=0) and gen_padding_mask(pad_idx=0)). With the
# previous 0-based meta_song.index, the first song was silently treated as padding.
meta_song["song_index"] = np.arange(1, len(meta_song) + 1)
song_index_lookup = meta_song[["song_id", "song_index"]]

# Same left-merge for every frame, then drop the raw song_id column.
# NOTE(review): songs missing from meta_song get a NaN song_index here.
train_source, train_target, test_source = (
    frame.merge(song_index_lookup, on="song_id", how="left").drop(columns="song_id")
    for frame in (train_source, train_target, test_source)
)

In [None]:
# Sliding windows of n + 1 consecutive songs that stay inside one session.
def getTrainData(df, n=2):
    """Build windows of ``n + 1`` consecutive song indices within a session.

    Assumes ``df`` is sorted by session and listening order, so that each
    session occupies one contiguous run of rows.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``session_id`` and ``song_index`` columns. It is not modified.
    n : int
        Number of songs following the current one (window length is ``n + 1``).

    Returns
    -------
    pandas.DataFrame
        Columns ``song_index, next1_song_id, ..., next{n}_song_id``. Despite
        their name, the ``next*`` columns hold song indices; being produced by
        ``shift`` they have float dtype. There is one row per start position
        whose whole window lies in a single session, and the input's row index
        is preserved.

    Raises
    ------
    ValueError
        If ``n`` is negative.
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")

    # Copy only the two columns this function needs, not the whole frame.
    windows = df[["session_id", "song_index"]].copy()
    for i in range(1, n + 1):
        windows[f"next{i}_song_id"] = windows["song_index"].shift(-i)

    # Sessions are contiguous, so if the row n steps ahead still belongs to the
    # same session, every row in between does too. NaN (past the end) never matches.
    end_session = windows["session_id"].shift(-n)
    windows = windows[windows["session_id"] == end_session]

    return windows[["song_index"] + [f"next{i}_song_id" for i in range(1, n + 1)]]

In [None]:
# One window per session: n=19 gives 20-song source windows, n=4 gives 5-song
# target windows (presumably one per session; the asserts below verify pairing).
trainX = getTrainData(train_source, n=19)
trainY = getTrainData(train_target, n=4)

# Rows are paired purely by position below, so both frames must list the same
# sessions in the same order. Their index still refers to rows of the merged
# train_source / train_target frames, which lets us check this before resetting.
assert len(trainX) == len(trainY), (len(trainX), len(trainY))
assert (
    train_source.loc[trainX.index, "session_id"].to_numpy()
    == train_target.loc[trainY.index, "session_id"].to_numpy()
).all(), "source and target sessions are not aligned row by row"

# Token ids are integers. The shifted next*_song_id columns come back as float,
# so cast the whole frame. This keeps ids exact and fails loudly on NaN ids.
trainX = trainX.astype("int64").reset_index(drop=True)
trainY = trainY.astype("int64").reset_index(drop=True)

# Special tokens sit just past the largest song index (0 is the padding id).
sos = meta_song["song_index"].max() + 1
eos = meta_song["song_index"].max() + 2

# Encoder input: <sos> + last TARGET_LEN source songs + <eos>, so that source
# and target sequences end up with the same length.
TARGET_LEN = trainY.shape[1]
src = trainX.iloc[:, -TARGET_LEN:].copy()  # copy: columns are inserted below
src.insert(0, "sos", sos)
src["eos"] = eos

# Decoder sequence: <sos> + target songs + <eos>. Copy so trainY is untouched.
tgt = trainY.copy()
tgt.insert(0, "sos", sos)
tgt["eos"] = eos

assert src.shape == tgt.shape, (src.shape, tgt.shape)
src.shape

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Row-aligned (source, target) token-id sequences for the seq2seq model.

    Keeps the name ``Dataset`` used elsewhere in this notebook. It subclasses
    ``torch.utils.data.Dataset`` explicitly so that re-running this cell does
    not subclass the previous definition of itself.

    Args:
        source_data: DataFrame of encoder token ids, one sequence per row.
        target_data: DataFrame of decoder token ids, row-aligned with ``source_data``.

    Raises:
        ValueError: if the two frames have different numbers of rows.
    """

    def __init__(self, source_data, target_data):
        if len(source_data) != len(target_data):
            raise ValueError(
                "source_data and target_data must have the same length, "
                f"got {len(source_data)} and {len(target_data)}"
            )
        self.source_data = source_data
        self.target_data = target_data
        # Convert once up front: a DataFrame.iloc lookup per item is slow.
        # Values are cast to int64 (long), as the model's embedding expects.
        self._src = torch.tensor(source_data.to_numpy(), dtype=torch.long)
        self._tgt = torch.tensor(target_data.to_numpy(), dtype=torch.long)

    def __len__(self):
        return len(self.source_data)

    def __getitem__(self, idx):
        # One row of each frame as 1-D LongTensors.
        return self._src[idx], self._tgt[idx]

In [None]:
from sklearn.model_selection import train_test_split

BATCH_SIZE = 64
VAL_FRACTION = 0.1
SPLIT_SEED = 42
# Only this many held-out pairs are used for validation. validation() decodes
# greedily with one full forward pass per target position, so keep it small.
N_VAL_SAMPLES = 16

# Split into train / validation sets. New names are used instead of overwriting
# src/tgt, so re-running this cell does not split the training set a second time.
src_train, src_val, tgt_train, tgt_val = train_test_split(
    src, tgt, test_size=VAL_FRACTION, random_state=SPLIT_SEED
)

src_val = src_val.iloc[:N_VAL_SAMPLES]
tgt_val = tgt_val.iloc[:N_VAL_SAMPLES]

train_dataset = Dataset(source_data=src_train, target_data=tgt_train)
validation_dataset = Dataset(source_data=src_val, target_data=tgt_val)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Number of distinct (non-missing) song indices in the training source sessions.
train_source["song_index"].dropna().nunique()

In [None]:
MAX_SEQ_LEN = 7  # <sos> + 5 songs + <eos>
# Token ids span 0 (padding id), the song indices, <sos> = max + 1 and
# <eos> = max + 2, so the vocabulary needs max + 3 entries. Cast to a plain
# Python int (pandas returns a numpy integer) before using it as a layer size.
VOCAB_SIZE = int(meta_song["song_index"].max()) + 3
EMBEDDING_DIM = 32
NHEAD = 8  # must divide EMBEDDING_DIM
NUM_ENCODER_LAYERS = 4
NUM_DECODER_LAYERS = 4

In [None]:
import math
from torch import Tensor


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to an embedded sequence, then apply dropout.

    Args:
        d_model: embedding size of the inputs (odd sizes are supported).
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence length that can be encoded.
        batch_first: if True, inputs are ``(batch, seq, d_model)``;
            otherwise they are ``(seq, batch, d_model)``.
    """

    def __init__(
        self,
        d_model: int,
        dropout: float = 0.1,
        max_len: int = 5000,
        batch_first: bool = False,
    ):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # When d_model is odd there is one fewer cosine column than sine column.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe)
        self.batch_first = batch_first

    def forward(self, x: Tensor) -> Tensor:
        if self.batch_first:
            # pe is (max_len, 1, d_model); turn the slice into (1, seq, d_model)
            # so it broadcasts over the batch dimension.
            x = x + self.pe[: x.size(1)].transpose(0, 1)
        else:
            x = x + self.pe[: x.size(0)]
        return self.dropout(x)

In [None]:
class TransformerModel(nn.Module):
    """Encoder-decoder Transformer over song-id tokens.

    Token ids are embedded (id 0 is the embedding's padding index), given
    sinusoidal positional encodings, passed through a batch-first
    ``nn.Transformer``, and projected to logits over the whole vocabulary.

    Uses the notebook-level constants VOCAB_SIZE, EMBEDDING_DIM, NHEAD,
    NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS and MAX_SEQ_LEN.
    """

    def __init__(self):
        super(TransformerModel, self).__init__()

        self.embedding = nn.Embedding(VOCAB_SIZE, EMBEDDING_DIM, padding_idx=0)
        self.transformer = nn.Transformer(
            d_model=EMBEDDING_DIM,
            nhead=NHEAD,
            num_encoder_layers=NUM_ENCODER_LAYERS,
            num_decoder_layers=NUM_DECODER_LAYERS,
            batch_first=True,
        )
        self.pos_embedding = PositionalEncoding(
            EMBEDDING_DIM, dropout=0.1, max_len=MAX_SEQ_LEN, batch_first=True
        )
        self.fc = nn.Linear(EMBEDDING_DIM, VOCAB_SIZE)

    def forward(self, src, tgt, src_padding_mask, tgt_padding_mask, tgt_mask):
        """Return logits of shape ``(batch, tgt_len, VOCAB_SIZE)``.

        Args:
            src: ``(batch, src_len)`` encoder token ids.
            tgt: ``(batch, tgt_len)`` decoder input token ids.
            src_padding_mask: ``(batch, src_len)`` bool mask, True marks padding.
            tgt_padding_mask: ``(batch, tgt_len)`` bool mask, True marks padding.
            tgt_mask: ``(tgt_len, tgt_len)`` causal mask, True marks blocked positions.
        """
        # Named temporaries instead of `_`, which would clobber IPython's
        # last-output variable.
        src_emb = self.pos_embedding(self.embedding(src))
        tgt_emb = self.pos_embedding(self.embedding(tgt))

        output = self.transformer(
            src_emb,
            tgt_emb,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            # Also keep the decoder's cross-attention off padded source
            # positions (previously only the encoder ignored them).
            memory_key_padding_mask=src_padding_mask,
        )

        return self.fc(output)


# Padding mask: marks every position holding the padding id.
def gen_padding_mask(src, pad_idx=0.0):
    """Return a bool tensor shaped like ``src`` that is True where ``src == pad_idx``."""
    is_pad = src == pad_idx
    return is_pad


# Causal (look-ahead) mask for the decoder.
def gen_mask(seq):
    """Build the causal mask for a decoder input whose dim 1 is the sequence length.

    Returns a ``(seq_len, seq_len)`` bool tensor where True means "may not
    attend", so position i only sees positions 0..i. The mask is created
    directly on ``seq``'s device. That avoids building it on the CPU and copying
    it over each step; a later ``.to(device)`` on it is then a no-op.
    """
    seq_len = seq.size(1)
    ones = torch.ones(seq_len, seq_len, device=seq.device)
    return torch.triu(ones, diagonal=1).bool()

In [None]:
# Make CUDA kernel launches synchronous so that errors point at the failing op.
# `!export ...` ran in a throw-away subshell and never reached this kernel.
# Setting os.environ changes the kernel process itself. CUDA reads this variable
# when it initialises, so it only takes effect if set before the first CUDA work
# (move it to the top of the notebook, or restart the kernel, if needed).
import os

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [None]:
def get_index(pred, dim=2):
    """Return the argmax indices of ``pred`` along ``dim``.

    With ``(batch, seq, vocab)`` logits and the default ``dim=2``, this is the
    greedy token prediction per position. ``argmax`` does not modify its input,
    so the previous ``clone()`` only copied the full logits tensor for nothing.
    """
    return pred.argmax(dim=dim)

In [None]:
def metrics(pred: list, target: list) -> float:
    """
    Exact-match accuracy of two aligned lists, in percent.

    pred: list of predictions (validation() passes token-id lists)
    target: list of targets, aligned element by element with ``pred``

    return: accuracy(%), i.e. the share of indices i with pred[i] == target[i]

    raises: ValueError if the lists differ in length or are empty
    """
    if len(pred) != len(target):
        raise ValueError("length of pred and target must be the same")
    if not pred:
        # Previously this surfaced as an opaque ZeroDivisionError.
        raise ValueError("pred and target must not be empty")
    correct = sum(1 for p, t in zip(pred, target) if p == t)
    return correct / len(pred) * 100

In [None]:
# Shape of the first validation target batch: (batch, target sequence length).
next(iter(validation_loader))[1].size()

In [None]:
def validation(dataloader, model, device, logout=False, dataset="test"):
    """Greedy-decode every batch in ``dataloader`` and report accuracy and loss.

    For each batch, the decoder input starts as ``[<sos>, pad, pad, ...]`` and
    is filled one position at a time with the model's argmax prediction. That
    is one full forward pass per target position. Accuracy is the exact-match
    rate of whole decoded sequences (including <sos>/<eos>) against the targets.

    Args:
        dataloader: yields ``(src, tgt)`` token-id batches.
        model: called with ``src``, ``tgt``, ``src_padding_mask``,
            ``tgt_padding_mask`` and ``tgt_mask`` keywords; returns
            ``(batch, tgt_len, vocab)`` logits.
        device: device each batch is moved to.
        logout: if True, print input / prediction / target for every sample.
        dataset: label used as the prefix of the printed summary.

    Returns:
        Mean over batches of the cross-entropy between the last decoding step's
        logits (``pred[:, :-1]``) and ``tgt[:, 1:]``.

    Raises:
        ValueError: if ``dataloader`` yields no batches.

    Gradients are disabled inside; the caller still chooses train/eval mode.
    """
    pred_str_list = []
    tgt_str_list = []
    input_str_list = []
    losses = []
    ce_loss = nn.CrossEntropyLoss()
    # Inference only: do not build an autograd graph across decoding steps.
    with torch.no_grad():
        for src, tgt in dataloader:
            src, tgt = src.to(device), tgt.to(device)
            # The source is fixed while decoding, so its mask is built once per batch.
            src_pad_mask = gen_padding_mask(src, pad_idx=0).to(device)
            # Decoder input: all pad tokens (0) except position 0, which is <sos>.
            tgt_input = torch.full_like(tgt, fill_value=0)
            tgt_input[:, 0] = sos
            for i in range(tgt.shape[1] - 1):
                # Not-yet-decoded positions are still 0 and so are masked as padding.
                tgt_pad_mask = gen_padding_mask(tgt_input, pad_idx=0).to(device)
                tgt_mask = gen_mask(tgt_input).to(device)
                pred = model(
                    src=src,
                    tgt=tgt_input,
                    src_padding_mask=src_pad_mask,
                    tgt_padding_mask=tgt_pad_mask,
                    tgt_mask=tgt_mask,
                )
                # The output at position i is the prediction for position i + 1.
                tgt_input[:, i + 1] = get_index(pred)[:, i]

            for row in range(tgt.shape[0]):
                pred_str_list.append(tgt_input[row].tolist())
                tgt_str_list.append(tgt[row].tolist())
                input_str_list.append(src[row].tolist())
                if logout:
                    print("=" * 30)
                    print(f"input: {input_str_list[-1]}")
                    print(f"pred: {pred_str_list[-1]}")
                    print(f"target: {tgt_str_list[-1]}")
            # Loss of the final decoding step's logits against the shifted target.
            loss = ce_loss(pred[:, :-1, :].permute(0, 2, 1), tgt[:, 1:])
            losses.append(loss.item())

    if not losses:
        raise ValueError("dataloader yielded no batches")
    avg_loss = sum(losses) / len(losses)

    print(
        f"{dataset}_acc: {metrics(pred_str_list, tgt_str_list):.2f}",
        f"{dataset}_loss: {avg_loss:.2f}",
        end=" | ",
    )
    print(f"[pred: {pred_str_list[0]} target: {tgt_str_list[0]}]")
    return avg_loss

In [None]:
# Display the torch.device chosen earlier (CUDA if available, else CPU).
device

In [None]:
import matplotlib.pyplot as plt

# Training configuration.
SEED = 42
LEARNING_RATE = 0.01
EPOCHS = 20
LOG_EVERY = 500  # record the loss and redraw result.png every N batches
# Anomaly detection makes every backward pass much slower; enable it only while
# hunting NaN/inf gradients.
DETECT_ANOMALY = False

torch.manual_seed(SEED)

model = TransformerModel().to(device=device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

train_losses = []
with torch.autograd.set_detect_anomaly(DETECT_ANOMALY):
    for epoch in range(EPOCHS):
        model.train()
        for step, (train_x, train_y) in enumerate(
            tqdm(train_loader, desc=f"Epoch {epoch + 1}/{EPOCHS}")
        ):
            train_x, train_y = train_x.to(device=device), train_y.to(device=device)

            # Teacher forcing: the decoder reads tokens 0..T-2 and must predict
            # tokens 1..T-1, matching validation(). The previous code scored the
            # output against the very sequence it was fed. Because the causal
            # mask lets position i see itself, the model could just copy its input.
            decoder_input = train_y[:, :-1]
            decoder_target = train_y[:, 1:]

            optimizer.zero_grad()
            src_pad_mask = gen_padding_mask(train_x).to(device)
            tgt_pad_mask = gen_padding_mask(decoder_input).to(device)
            tgt_mask = gen_mask(decoder_input).to(device)

            output = model(
                train_x,
                decoder_input,
                src_padding_mask=src_pad_mask,
                tgt_padding_mask=tgt_pad_mask,
                tgt_mask=tgt_mask,
            )

            # CrossEntropyLoss wants (batch, vocab, seq) logits vs (batch, seq) targets.
            loss = criterion(output.permute(0, 2, 1), decoder_target)
            loss.backward()
            optimizer.step()

            if step % LOG_EVERY == 0:
                train_losses.append(loss.item())
                fig, ax = plt.subplots(figsize=(10, 5))
                ax.plot(train_losses, label="Training Loss")
                ax.set_xlabel(f"Logged step (every {LOG_EVERY} batches)")
                ax.set_ylabel("CrossEntropy Loss")
                ax.set_title("Training Loss Curve")
                fig.savefig("result.png")
                # Close the figure so it is not displayed and its memory is freed.
                plt.close(fig)

        # Greedy-decoding validation on the held-out pairs (prints its own summary).
        model.eval()
        with torch.no_grad():
            val_loss = validation(validation_loader, model, device)

        print(f"Epoch {epoch + 1}/{EPOCHS}, Loss: {loss.item()}")