In [1]:
from pandas import read_csv
from torch.optim import SGD
from torch.utils.data import DataLoader
from src.runner import Runner
from src.loss import MiningOutliersLoss
from src.model import MF
from src.utils import DataConverter, DataProcessor, create_dataset, mine_outliers
from config import DATA_DIR

# Per-rater, song-level static annotations file for DEAM songs 1-2000.
# Only the first piece needs to be an f-string; the rest are plain literals.
DF_PATH = (
    f"{DATA_DIR}"
    "/DEAM/annotations/annotations per each rater/"
    "song_level/static_annotations_songs_1_2000.csv"
)

In [2]:
# Load per-rater song-level valence ratings and rename them to the
# user/item/rating schema handed to DataConverter.
columns = ["workerID", "SongId", "Valence"]
# pandas ignores the element order of ``usecols`` and returns the columns in
# file order, so re-select ``columns`` explicitly; otherwise the positional
# rename below could silently swap users, items and ratings.
original_df = read_csv(DF_PATH, skipinitialspace=True, usecols=columns)[columns]
original_df.columns = ["user_id", "item_id", "rating"]

# Hyperparameters for this run (values unchanged from the original cell).
N_RANDOM_USERS = 10
N_RATINGS_PER_RANDOM_USER = 200
# NOTE(review): lr=5 is unusually large for SGD; presumably tuned to the scale
# of MiningOutliersLoss (printed epoch losses are in the thousands) — confirm.
LEARNING_RATE = 5
WEIGHT_DECAY = 1e-5
BATCH_SIZE = 1000

# NOTE(review): semantics of the random-user arguments live in src.utils;
# presumably these are synthetic raters used as known outliers — confirm.
data_converter = DataConverter(
    original_df=original_df,
    n_random_users=N_RANDOM_USERS,
    n_ratings_per_random_user=N_RATINGS_PER_RANDOM_USER,
)
# Built from the converter's frame rather than the raw ``original_df``.
data_processor = DataProcessor(original_df=data_converter.original_df)

model = MF(
    n_users=data_converter.n_users,
    # NOTE(review): singular ``n_item`` next to plural ``n_users``; verify the
    # attribute name against DataConverter.
    n_items=data_converter.n_item,
)

criterion = MiningOutliersLoss(data_converter=data_converter, data_processor=data_processor)
optimizer = SGD(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
runner = Runner(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
)

train_set = create_dataset(data_converter=data_converter)
train_load = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
epochs = 75
for epoch in range(epochs):
    epoch_loss = runner.train(train_loader=train_load)
    print(f"epoch={epoch + 1}, loss={epoch_loss}")

# Runs only after every epoch has completed.
mine_outliers(model=model, data_converter=data_converter)

100%|██████████| 20/20 [00:09<00:00,  2.07batch/s, train_loss=216]
  0%|          | 0/20 [00:00<?, ?batch/s]

epoch=1, loss=11308.120407104492


100%|██████████| 20/20 [00:07<00:00,  2.77batch/s, train_loss=166]
  0%|          | 0/20 [00:00<?, ?batch/s]

epoch=2, loss=7888.20051574707


100%|██████████| 20/20 [00:04<00:00,  4.07batch/s, train_loss=152]
  0%|          | 0/20 [00:00<?, ?batch/s]

epoch=3, loss=6588.058670043945


100%|██████████| 20/20 [00:04<00:00,  4.26batch/s, train_loss=134]
  0%|          | 0/20 [00:00<?, ?batch/s]

epoch=4, loss=5653.770309448242


 80%|████████  | 16/20 [00:04<00:01,  3.28batch/s, train_loss=261]


KeyboardInterrupt: 