In [12]:
import torch
import os
from pandas import read_csv
from torch.optim import SGD
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from config import DATA_DIR, MODELS_DIR
from src.loss import Loss
from src.model import HistogramMF
from src.runner import Runner
from src.create_dataset import create_dataset
from src.data_processor import DataProcessor
from src.data_encoder import DataEncoder

# Location of the per-rater, song-level static annotation CSV inside DATA_DIR.
DF_PATH = "/".join(
    [
        f"{DATA_DIR}",
        "DEAM",
        "annotations",
        "annotations per each rater",
        "song_level",
        "static_annotations_songs_1_2000_raw.csv",
    ]
)

In [13]:
# Raw CSV column name -> column name used in the rest of the notebook.
column_renames = {
    "workerID": "user_id",
    "SongId": "item_id",
    "Valence": "Valence",
    "Arousal": "Arousal",
}
columns = list(column_renames)

# `usecols` only filters columns; it does not reorder them. Select and rename
# by name so the labels cannot be attached to the wrong data if the physical
# column order of the CSV differs from `columns`.
original_df = (
    read_csv(DF_PATH, skipinitialspace=True, usecols=columns)[columns]
    .rename(columns=column_renames)
)
original_df.head()

Unnamed: 0,user_id,item_id,Valence,Arousal
0,6010bbc8e7ef4b21fa38f9c3a9754ef3,2,5,2
1,3c888e77b992ae3cd2adfe16774e23b9,2,2,3
2,2afd218c3aecb6828d2be327f8b9c46f,2,3,3
3,fd5b08ce362d855ca9152a894348130c,2,4,4
4,9c8073214a052e414811b76012df8847,2,2,2


In [14]:
# Build the Valence model: restore it from its checkpoint when one exists,
# otherwise train it and write the checkpoint.
valence_dataframe = original_df[["user_id", "item_id", "Valence"]].copy()
valence_dataframe.columns = ["user_id", "item_id", "rating"]

data_encoder = DataEncoder(original_df=valence_dataframe)
data_processor = DataProcessor(original_df=valence_dataframe)

n_users = valence_dataframe.user_id.nunique()
n_items = valence_dataframe.item_id.nunique()

# Rating bounds observed in the data.
min_rating = valence_dataframe.rating.min()
max_rating = valence_dataframe.rating.max()

valence_model = HistogramMF(
    n_users=n_users,
    n_items=n_items,
    data_encoder=data_encoder,
    data_processor=data_processor,
    min_rating=min_rating,
    max_rating=max_rating,
)

valence_checkpoint = f"{MODELS_DIR}/DEAM/valence.pt"

if os.path.exists(valence_checkpoint):
    valence_model.load_state_dict(torch.load(valence_checkpoint))
else:
    # Create the checkpoint directory up front so a missing folder is not
    # discovered by torch.save only after all epochs have been trained.
    os.makedirs(os.path.dirname(valence_checkpoint), exist_ok=True)

    epochs = 100

    criterion = Loss()
    optimizer = SGD(valence_model.parameters(), lr=5, weight_decay=1e-7)
    runner = Runner(model=valence_model, criterion=criterion, optimizer=optimizer)

    train_set = create_dataset(data_encoder=data_encoder)
    train_load = DataLoader(train_set, batch_size=1000, shuffle=True)

    with SummaryWriter("runs/DEAM/dev/valence") as writer:
        for epoch in range(epochs):
            epoch_loss = runner.train(train_loader=train_load, epoch=epoch, writer=writer)
            print(f"epoch={epoch + 1}, loss={epoch_loss}")

    torch.save(valence_model.state_dict(), valence_checkpoint)

  0%|          | 0/17464 [00:00<?, ?it/s]

  0%|          | 0/17464 [00:00<?, ?it/s]

100%|██████████| 18/18 [00:02<00:00,  6.07batch/s, train_loss=0.0732]


epoch=1, loss=1.4918756646123428


100%|██████████| 18/18 [00:02<00:00,  8.25batch/s, train_loss=0.0632]


epoch=2, loss=1.1659571734132437


100%|██████████| 18/18 [00:02<00:00,  8.32batch/s, train_loss=0.0589]


epoch=3, loss=1.0371634668152907


100%|██████████| 18/18 [00:02<00:00,  8.12batch/s, train_loss=0.0497]


epoch=4, loss=0.8966528284138646


100%|██████████| 18/18 [00:02<00:00,  8.75batch/s, train_loss=0.0446]


epoch=5, loss=0.7819513364331476


100%|██████████| 18/18 [00:02<00:00,  7.25batch/s, train_loss=0.0412]


epoch=6, loss=0.6838732398460651


100%|██████████| 18/18 [00:02<00:00,  7.00batch/s, train_loss=0.0345]


epoch=7, loss=0.6099536052243463


100%|██████████| 18/18 [00:02<00:00,  8.12batch/s, train_loss=0.0337]


epoch=8, loss=0.5401770396068178


100%|██████████| 18/18 [00:02<00:00,  6.44batch/s, train_loss=0.0268]


epoch=9, loss=0.465124990068633


100%|██████████| 18/18 [00:02<00:00,  8.15batch/s, train_loss=0.0238]


epoch=10, loss=0.39707756221705465


100%|██████████| 18/18 [00:02<00:00,  7.65batch/s, train_loss=0.0189]


epoch=11, loss=0.3369934104064415


100%|██████████| 18/18 [00:02<00:00,  7.98batch/s, train_loss=0.018] 


epoch=12, loss=0.28484913717467214


100%|██████████| 18/18 [00:02<00:00,  7.66batch/s, train_loss=0.0142]


epoch=13, loss=0.2367570889242764


100%|██████████| 18/18 [00:02<00:00,  7.94batch/s, train_loss=0.0112] 


epoch=14, loss=0.19707339739799504


100%|██████████| 18/18 [00:02<00:00,  8.49batch/s, train_loss=0.0081] 


epoch=15, loss=0.16647859665854228


100%|██████████| 18/18 [00:02<00:00,  7.52batch/s, train_loss=0.0086] 


epoch=16, loss=0.14269362223559415


100%|██████████| 18/18 [00:02<00:00,  7.44batch/s, train_loss=0.00774]


epoch=17, loss=0.12408277347581141


100%|██████████| 18/18 [00:02<00:00,  8.11batch/s, train_loss=0.00672]


epoch=18, loss=0.10773165455768847


100%|██████████| 18/18 [00:02<00:00,  8.35batch/s, train_loss=0.00493]


epoch=19, loss=0.09449487268513648


100%|██████████| 18/18 [00:02<00:00,  8.07batch/s, train_loss=0.00471]


epoch=20, loss=0.08262281827679999


100%|██████████| 18/18 [00:02<00:00,  7.94batch/s, train_loss=0.00345]


epoch=21, loss=0.0747388231261023


100%|██████████| 18/18 [00:02<00:00,  7.96batch/s, train_loss=0.00439]


epoch=22, loss=0.06815497736273142


100%|██████████| 18/18 [00:02<00:00,  8.23batch/s, train_loss=0.00359]


epoch=23, loss=0.061644526311035806


100%|██████████| 18/18 [00:02<00:00,  7.13batch/s, train_loss=0.00399]


epoch=24, loss=0.05865883575431231


100%|██████████| 18/18 [00:02<00:00,  8.42batch/s, train_loss=0.00376]


epoch=25, loss=0.05315861481222613


100%|██████████| 18/18 [00:02<00:00,  8.02batch/s, train_loss=0.00262]


epoch=26, loss=0.04857661248486617


100%|██████████| 18/18 [00:02<00:00,  8.23batch/s, train_loss=0.00269]


epoch=27, loss=0.0442258097632178


100%|██████████| 18/18 [00:02<00:00,  7.47batch/s, train_loss=0.00312]


epoch=28, loss=0.04180865992759836


100%|██████████| 18/18 [00:02<00:00,  8.24batch/s, train_loss=0.00263]


epoch=29, loss=0.038094673832942696


100%|██████████| 18/18 [00:02<00:00,  8.46batch/s, train_loss=0.00203]


epoch=30, loss=0.03601468171127911


100%|██████████| 18/18 [00:02<00:00,  8.50batch/s, train_loss=0.00268]


epoch=31, loss=0.033708528654328704


100%|██████████| 18/18 [00:02<00:00,  8.25batch/s, train_loss=0.00216]


epoch=32, loss=0.03128287810703804


100%|██████████| 18/18 [00:02<00:00,  8.75batch/s, train_loss=0.00208]


epoch=33, loss=0.030103400101949427


100%|██████████| 18/18 [00:02<00:00,  8.11batch/s, train_loss=0.00208]


epoch=34, loss=0.0282550440040128


100%|██████████| 18/18 [00:02<00:00,  8.76batch/s, train_loss=0.00159]


epoch=35, loss=0.027456199888525337


100%|██████████| 18/18 [00:02<00:00,  8.28batch/s, train_loss=0.00258] 


epoch=36, loss=0.026748509232340185


100%|██████████| 18/18 [00:02<00:00,  8.55batch/s, train_loss=0.00156]


epoch=37, loss=0.025094939005785973


100%|██████████| 18/18 [00:02<00:00,  8.59batch/s, train_loss=0.00167]


epoch=38, loss=0.024662357695143803


100%|██████████| 18/18 [00:02<00:00,  8.55batch/s, train_loss=0.00108]


epoch=39, loss=0.02333713969382746


100%|██████████| 18/18 [00:02<00:00,  8.54batch/s, train_loss=0.000753]


epoch=40, loss=0.02249510441617719


100%|██████████| 18/18 [00:02<00:00,  8.52batch/s, train_loss=0.00113] 


epoch=41, loss=0.02212891823772727


100%|██████████| 18/18 [00:02<00:00,  8.33batch/s, train_loss=0.000569]


epoch=42, loss=0.02166013054549694


100%|██████████| 18/18 [00:02<00:00,  8.82batch/s, train_loss=0.00225] 


epoch=43, loss=0.021153650690769327


100%|██████████| 18/18 [00:02<00:00,  8.30batch/s, train_loss=0.000531]


epoch=44, loss=0.02008879695961188


100%|██████████| 18/18 [00:02<00:00,  8.57batch/s, train_loss=0.00201] 


epoch=45, loss=0.020329654604196546


100%|██████████| 18/18 [00:02<00:00,  8.54batch/s, train_loss=0.0014]  


epoch=46, loss=0.019665250862466878


100%|██████████| 18/18 [00:02<00:00,  8.61batch/s, train_loss=0.00106] 


epoch=47, loss=0.019464602054706934


100%|██████████| 18/18 [00:02<00:00,  8.57batch/s, train_loss=0.00129] 


epoch=48, loss=0.019247237066770423


100%|██████████| 18/18 [00:02<00:00,  8.60batch/s, train_loss=0.00122] 


epoch=49, loss=0.018802558920506772


100%|██████████| 18/18 [00:02<00:00,  8.29batch/s, train_loss=0.00061] 


epoch=50, loss=0.018734573096550744


100%|██████████| 18/18 [00:02<00:00,  8.25batch/s, train_loss=0.00122] 


epoch=51, loss=0.018821179524577895


100%|██████████| 18/18 [00:02<00:00,  8.10batch/s, train_loss=0.00102] 


epoch=52, loss=0.01927803964707358


100%|██████████| 18/18 [00:02<00:00,  8.86batch/s, train_loss=0.00175] 


epoch=53, loss=0.01889551408434736


100%|██████████| 18/18 [00:02<00:00,  8.29batch/s, train_loss=0.00122] 


epoch=54, loss=0.01863418610548151


100%|██████████| 18/18 [00:02<00:00,  8.59batch/s, train_loss=0.0013]  


epoch=55, loss=0.018258316785097123


100%|██████████| 18/18 [00:02<00:00,  8.35batch/s, train_loss=0.00121] 


epoch=56, loss=0.01832774330829752


100%|██████████| 18/18 [00:02<00:00,  8.54batch/s, train_loss=0.00108] 


epoch=57, loss=0.018596900835633277


100%|██████████| 18/18 [00:02<00:00,  8.52batch/s, train_loss=0.000773]


epoch=58, loss=0.017890421113577384


100%|██████████| 18/18 [00:02<00:00,  7.95batch/s, train_loss=0.00107] 


epoch=59, loss=0.0186101315689498


100%|██████████| 18/18 [00:02<00:00,  6.88batch/s, train_loss=0.00143] 


epoch=60, loss=0.018569117518334555


100%|██████████| 18/18 [00:02<00:00,  8.60batch/s, train_loss=0.00167] 


epoch=61, loss=0.018221392220464247


100%|██████████| 18/18 [00:02<00:00,  8.13batch/s, train_loss=0.000656]


epoch=62, loss=0.017910856486394486


100%|██████████| 18/18 [00:02<00:00,  8.35batch/s, train_loss=0.000928]


epoch=63, loss=0.01805274901873079


100%|██████████| 18/18 [00:02<00:00,  8.41batch/s, train_loss=0.000968]


epoch=64, loss=0.01742546550929546


100%|██████████| 18/18 [00:02<00:00,  8.57batch/s, train_loss=0.00103] 


epoch=65, loss=0.017722615987062453


100%|██████████| 18/18 [00:02<00:00,  8.49batch/s, train_loss=0.00137] 


epoch=66, loss=0.01815839174389839


100%|██████████| 18/18 [00:02<00:00,  8.45batch/s, train_loss=0.000874]


epoch=67, loss=0.017595676500221775


100%|██████████| 18/18 [00:02<00:00,  8.24batch/s, train_loss=0.00097] 


epoch=68, loss=0.017502988620564854


100%|██████████| 18/18 [00:02<00:00,  8.18batch/s, train_loss=0.000851]


epoch=69, loss=0.01787270340272065


100%|██████████| 18/18 [00:02<00:00,  8.24batch/s, train_loss=0.00161] 


epoch=70, loss=0.018084863539399774


100%|██████████| 18/18 [00:02<00:00,  8.83batch/s, train_loss=0.000624]


epoch=71, loss=0.01726967101723983


100%|██████████| 18/18 [00:02<00:00,  8.28batch/s, train_loss=0.0008]  


epoch=72, loss=0.01785903632383922


100%|██████████| 18/18 [00:02<00:00,  8.48batch/s, train_loss=0.000634]


epoch=73, loss=0.017493314848377788


100%|██████████| 18/18 [00:02<00:00,  8.52batch/s, train_loss=0.00164] 


epoch=74, loss=0.01789117240802995


100%|██████████| 18/18 [00:02<00:00,  8.58batch/s, train_loss=0.00138] 


epoch=75, loss=0.017880446202796083


100%|██████████| 18/18 [00:02<00:00,  8.56batch/s, train_loss=0.000854]


epoch=76, loss=0.017870749874361627


100%|██████████| 18/18 [00:02<00:00,  8.53batch/s, train_loss=0.000916]


epoch=77, loss=0.017652892524826113


100%|██████████| 18/18 [00:02<00:00,  8.19batch/s, train_loss=0.00105] 


epoch=78, loss=0.01804624393736494


100%|██████████| 18/18 [00:02<00:00,  8.80batch/s, train_loss=0.000847]


epoch=79, loss=0.017604830337495637


100%|██████████| 18/18 [00:02<00:00,  7.31batch/s, train_loss=0.00123] 


epoch=80, loss=0.017849768241931654


100%|██████████| 18/18 [00:02<00:00,  8.37batch/s, train_loss=0.00106] 


epoch=81, loss=0.018163895997507815


100%|██████████| 18/18 [00:02<00:00,  8.37batch/s, train_loss=0.00129] 


epoch=82, loss=0.017754320844494062


100%|██████████| 18/18 [00:02<00:00,  8.38batch/s, train_loss=0.00101] 


epoch=83, loss=0.017799366928380116


100%|██████████| 18/18 [00:02<00:00,  8.36batch/s, train_loss=0.0014]  


epoch=84, loss=0.01810576706610877


100%|██████████| 18/18 [00:02<00:00,  7.88batch/s, train_loss=0.000983]


epoch=85, loss=0.017505382322031877


100%|██████████| 18/18 [00:02<00:00,  7.93batch/s, train_loss=0.000951]


epoch=86, loss=0.017942920198214463


100%|██████████| 18/18 [00:02<00:00,  8.43batch/s, train_loss=0.00111] 


epoch=87, loss=0.018191414475440978


100%|██████████| 18/18 [00:02<00:00,  7.97batch/s, train_loss=0.00107] 


epoch=88, loss=0.01806292178045059


100%|██████████| 18/18 [00:02<00:00,  8.45batch/s, train_loss=0.00121] 


epoch=89, loss=0.017899172603056346


100%|██████████| 18/18 [00:02<00:00,  8.06batch/s, train_loss=0.00162] 


epoch=90, loss=0.018002296043881053


100%|██████████| 18/18 [00:02<00:00,  7.92batch/s, train_loss=0.000979]


epoch=91, loss=0.017901493884366133


100%|██████████| 18/18 [00:02<00:00,  7.90batch/s, train_loss=0.000802]


epoch=92, loss=0.01732736950855831


100%|██████████| 18/18 [00:02<00:00,  7.61batch/s, train_loss=0.000929]


epoch=93, loss=0.01754939649279775


100%|██████████| 18/18 [00:02<00:00,  7.99batch/s, train_loss=0.0014]  


epoch=94, loss=0.01769660920418542


100%|██████████| 18/18 [00:02<00:00,  7.63batch/s, train_loss=0.000777]


epoch=95, loss=0.01789583602753179


100%|██████████| 18/18 [00:02<00:00,  7.58batch/s, train_loss=0.000837]


epoch=96, loss=0.01734624614232573


100%|██████████| 18/18 [00:02<00:00,  7.47batch/s, train_loss=0.00156] 


epoch=97, loss=0.018189402687138527


100%|██████████| 18/18 [00:02<00:00,  7.41batch/s, train_loss=0.00124] 


epoch=98, loss=0.01783483492608728


100%|██████████| 18/18 [00:02<00:00,  6.90batch/s, train_loss=0.00111] 


epoch=99, loss=0.017989741171228476


100%|██████████| 18/18 [00:02<00:00,  8.06batch/s, train_loss=0.00123] 

epoch=100, loss=0.017744535649644917





In [15]:
# Build the Arousal model: restore it from its checkpoint when one exists,
# otherwise train it and write the checkpoint.
arousal_dataframe = original_df[["user_id", "item_id", "Arousal"]].copy()
arousal_dataframe.columns = ["user_id", "item_id", "rating"]

data_encoder = DataEncoder(original_df=arousal_dataframe)
data_processor = DataProcessor(original_df=arousal_dataframe)

n_users = arousal_dataframe.user_id.nunique()
n_items = arousal_dataframe.item_id.nunique()

# Rating bounds observed in the data.
min_rating = arousal_dataframe.rating.min()
max_rating = arousal_dataframe.rating.max()

arousal_model = HistogramMF(
    n_users=n_users,
    n_items=n_items,
    data_encoder=data_encoder,
    data_processor=data_processor,
    min_rating=min_rating,
    max_rating=max_rating,
)

arousal_checkpoint = f"{MODELS_DIR}/DEAM/arousal.pt"

if os.path.exists(arousal_checkpoint):
    # The arousal weights must go into arousal_model: loading them into
    # valence_model would overwrite the valence weights and leave
    # arousal_model at its freshly initialised parameters.
    arousal_model.load_state_dict(torch.load(arousal_checkpoint))
else:
    # Create the checkpoint directory up front so a missing folder is not
    # discovered by torch.save only after all epochs have been trained.
    os.makedirs(os.path.dirname(arousal_checkpoint), exist_ok=True)

    epochs = 100

    criterion = Loss()
    optimizer = SGD(arousal_model.parameters(), lr=5, weight_decay=1e-7)
    runner = Runner(model=arousal_model, criterion=criterion, optimizer=optimizer)

    train_set = create_dataset(data_encoder=data_encoder)
    train_load = DataLoader(train_set, batch_size=1000, shuffle=True)

    with SummaryWriter("runs/DEAM/dev/arousal") as writer:
        for epoch in range(epochs):
            epoch_loss = runner.train(train_loader=train_load, epoch=epoch, writer=writer)
            print(f"epoch={epoch + 1}, loss={epoch_loss}")

    torch.save(arousal_model.state_dict(), arousal_checkpoint)

  0%|          | 0/17464 [00:00<?, ?it/s]

  0%|          | 0/17464 [00:00<?, ?it/s]

100%|██████████| 18/18 [00:04<00:00,  4.12batch/s, train_loss=0.0801]


epoch=1, loss=1.5037995927744898


100%|██████████| 18/18 [00:03<00:00,  5.46batch/s, train_loss=0.0668]


epoch=2, loss=1.1652154498593559


100%|██████████| 18/18 [00:03<00:00,  5.97batch/s, train_loss=0.0525]


epoch=3, loss=0.9997535807510902


100%|██████████| 18/18 [00:03<00:00,  5.55batch/s, train_loss=0.048] 


epoch=4, loss=0.8546096723161896


100%|██████████| 18/18 [00:03<00:00,  4.71batch/s, train_loss=0.0431]


epoch=5, loss=0.7660555310742608


100%|██████████| 18/18 [00:02<00:00,  6.22batch/s, train_loss=0.0386]


epoch=6, loss=0.6864886585761761


100%|██████████| 18/18 [00:02<00:00,  7.32batch/s, train_loss=0.0359]


epoch=7, loss=0.6110562211069567


100%|██████████| 18/18 [00:03<00:00,  5.35batch/s, train_loss=0.0298]


epoch=8, loss=0.5375913905439706


100%|██████████| 18/18 [00:02<00:00,  6.89batch/s, train_loss=0.0277]


epoch=9, loss=0.46918069875651397


100%|██████████| 18/18 [00:03<00:00,  5.74batch/s, train_loss=0.0229]


epoch=10, loss=0.4026881200856176


100%|██████████| 18/18 [00:03<00:00,  5.76batch/s, train_loss=0.0209]


epoch=11, loss=0.33745164320386695


100%|██████████| 18/18 [00:02<00:00,  6.27batch/s, train_loss=0.017] 


epoch=12, loss=0.2758408615835782


100%|██████████| 18/18 [00:03<00:00,  5.03batch/s, train_loss=0.0156]


epoch=13, loss=0.22854371462197143


100%|██████████| 18/18 [00:03<00:00,  5.66batch/s, train_loss=0.011]  


epoch=14, loss=0.18485278474051378


100%|██████████| 18/18 [00:03<00:00,  5.19batch/s, train_loss=0.00805]


epoch=15, loss=0.15310113200648076


100%|██████████| 18/18 [00:02<00:00,  6.54batch/s, train_loss=0.00763]


epoch=16, loss=0.13195215552428674


100%|██████████| 18/18 [00:02<00:00,  7.09batch/s, train_loss=0.00599]


epoch=17, loss=0.11382324977578787


100%|██████████| 18/18 [00:02<00:00,  6.42batch/s, train_loss=0.00566]


epoch=18, loss=0.09871589492929393


100%|██████████| 18/18 [00:04<00:00,  4.38batch/s, train_loss=0.00569]


epoch=19, loss=0.08650066585376345


100%|██████████| 18/18 [00:03<00:00,  5.39batch/s, train_loss=0.00481]


epoch=20, loss=0.07806641987685499


100%|██████████| 18/18 [00:03<00:00,  4.84batch/s, train_loss=0.00431]


epoch=21, loss=0.06977609058495227


100%|██████████| 18/18 [00:03<00:00,  4.61batch/s, train_loss=0.00303]


epoch=22, loss=0.06283540706387883


100%|██████████| 18/18 [00:03<00:00,  5.06batch/s, train_loss=0.00268]


epoch=23, loss=0.05623705873201632


100%|██████████| 18/18 [00:03<00:00,  5.19batch/s, train_loss=0.00248]


epoch=24, loss=0.05117567556685414


100%|██████████| 18/18 [00:04<00:00,  3.96batch/s, train_loss=0.00338]


epoch=25, loss=0.047879296064376825


100%|██████████| 18/18 [00:04<00:00,  4.29batch/s, train_loss=0.00223]


epoch=26, loss=0.043436659730713936


100%|██████████| 18/18 [00:03<00:00,  5.04batch/s, train_loss=0.00195]


epoch=27, loss=0.04021648372658368


100%|██████████| 18/18 [00:02<00:00,  6.77batch/s, train_loss=0.00163]


epoch=28, loss=0.037503840307737214


100%|██████████| 18/18 [00:02<00:00,  7.59batch/s, train_loss=0.00207]


epoch=29, loss=0.034777061215762434


100%|██████████| 18/18 [00:02<00:00,  6.18batch/s, train_loss=0.00195]


epoch=30, loss=0.03208540445052344


100%|██████████| 18/18 [00:02<00:00,  7.06batch/s, train_loss=0.00161]


epoch=31, loss=0.02968057083775257


100%|██████████| 18/18 [00:02<00:00,  7.64batch/s, train_loss=0.00163]


epoch=32, loss=0.027896288398010973


100%|██████████| 18/18 [00:02<00:00,  7.59batch/s, train_loss=0.00166]


epoch=33, loss=0.02565664670385164


100%|██████████| 18/18 [00:02<00:00,  6.38batch/s, train_loss=0.00124] 


epoch=34, loss=0.024325169439973506


100%|██████████| 18/18 [00:02<00:00,  7.66batch/s, train_loss=0.0017]  


epoch=35, loss=0.023789699544166695


100%|██████████| 18/18 [00:02<00:00,  7.78batch/s, train_loss=0.00143]


epoch=36, loss=0.02269037079913863


100%|██████████| 18/18 [00:02<00:00,  6.93batch/s, train_loss=0.00151]


epoch=37, loss=0.021831473981512004


100%|██████████| 18/18 [00:02<00:00,  6.93batch/s, train_loss=0.000927]


epoch=38, loss=0.020414701854360513


100%|██████████| 18/18 [00:02<00:00,  7.96batch/s, train_loss=0.00138] 


epoch=39, loss=0.02067336857216111


100%|██████████| 18/18 [00:02<00:00,  7.55batch/s, train_loss=0.000518]


epoch=40, loss=0.019780174778196317


100%|██████████| 18/18 [00:02<00:00,  6.51batch/s, train_loss=0.00143] 


epoch=41, loss=0.019604185859704838


100%|██████████| 18/18 [00:02<00:00,  7.38batch/s, train_loss=0.0013]  


epoch=42, loss=0.01875696748700635


100%|██████████| 18/18 [00:02<00:00,  7.89batch/s, train_loss=0.00133] 


epoch=43, loss=0.01846582660798369


100%|██████████| 18/18 [00:02<00:00,  7.05batch/s, train_loss=0.000825]


epoch=44, loss=0.017890583413428272


100%|██████████| 18/18 [00:02<00:00,  6.94batch/s, train_loss=0.000727]


epoch=45, loss=0.017710523440920074


100%|██████████| 18/18 [00:02<00:00,  7.54batch/s, train_loss=0.00101] 


epoch=46, loss=0.0174130333970333


100%|██████████| 18/18 [00:02<00:00,  8.15batch/s, train_loss=0.00132] 


epoch=47, loss=0.01726701797904639


100%|██████████| 18/18 [00:02<00:00,  7.63batch/s, train_loss=0.000881]


epoch=48, loss=0.016361025997276964


100%|██████████| 18/18 [00:02<00:00,  6.74batch/s, train_loss=0.00114] 


epoch=49, loss=0.01699291652852091


100%|██████████| 18/18 [00:02<00:00,  7.19batch/s, train_loss=0.000969]


epoch=50, loss=0.01661421840355314


100%|██████████| 18/18 [00:03<00:00,  5.97batch/s, train_loss=0.00099] 


epoch=51, loss=0.016452090427793307


100%|██████████| 18/18 [00:02<00:00,  6.01batch/s, train_loss=0.00089] 


epoch=52, loss=0.016501363487593058


100%|██████████| 18/18 [00:02<00:00,  6.01batch/s, train_loss=0.00145] 


epoch=53, loss=0.01676527734768802


100%|██████████| 18/18 [00:02<00:00,  7.03batch/s, train_loss=0.0012]  


epoch=54, loss=0.01633981393431795


100%|██████████| 18/18 [00:02<00:00,  7.06batch/s, train_loss=0.000723]


epoch=55, loss=0.016077371873218442


100%|██████████| 18/18 [00:03<00:00,  5.82batch/s, train_loss=0.000977]


epoch=56, loss=0.01605913810534724


100%|██████████| 18/18 [00:02<00:00,  6.25batch/s, train_loss=0.000886]


epoch=57, loss=0.01608221905806969


100%|██████████| 18/18 [00:03<00:00,  5.92batch/s, train_loss=0.000668]


epoch=58, loss=0.016260621983943315


100%|██████████| 18/18 [00:03<00:00,  5.97batch/s, train_loss=0.000877]


epoch=59, loss=0.01603781113028526


100%|██████████| 18/18 [00:02<00:00,  6.17batch/s, train_loss=0.000869]


epoch=60, loss=0.01607453647975264


100%|██████████| 18/18 [00:02<00:00,  6.12batch/s, train_loss=0.00111] 


epoch=61, loss=0.015870530752272443


100%|██████████| 18/18 [00:03<00:00,  5.94batch/s, train_loss=0.00102] 


epoch=62, loss=0.015638239258836055


100%|██████████| 18/18 [00:02<00:00,  6.25batch/s, train_loss=0.000789]


epoch=63, loss=0.015567670944949676


100%|██████████| 18/18 [00:02<00:00,  6.00batch/s, train_loss=0.000666]


epoch=64, loss=0.015570746343197496


100%|██████████| 18/18 [00:02<00:00,  6.66batch/s, train_loss=0.00113] 


epoch=65, loss=0.015735292718328277


100%|██████████| 18/18 [00:03<00:00,  5.41batch/s, train_loss=0.00118] 


epoch=66, loss=0.015867024936552704


100%|██████████| 18/18 [00:03<00:00,  5.85batch/s, train_loss=0.000715]


epoch=67, loss=0.015372225512204499


100%|██████████| 18/18 [00:02<00:00,  6.87batch/s, train_loss=0.000556]


epoch=68, loss=0.015294238537549973


100%|██████████| 18/18 [00:02<00:00,  6.60batch/s, train_loss=0.000589]


epoch=69, loss=0.015384434814083166


100%|██████████| 18/18 [00:03<00:00,  5.55batch/s, train_loss=0.000709]


epoch=70, loss=0.015371211808303307


100%|██████████| 18/18 [00:03<00:00,  5.03batch/s, train_loss=0.000726]


epoch=71, loss=0.01533691788336326


100%|██████████| 18/18 [00:03<00:00,  5.09batch/s, train_loss=0.00107] 


epoch=72, loss=0.01541091866945398


100%|██████████| 18/18 [00:03<00:00,  5.18batch/s, train_loss=0.000902]


epoch=73, loss=0.015473104549893018


100%|██████████| 18/18 [00:03<00:00,  4.73batch/s, train_loss=0.000569]


epoch=74, loss=0.015366551707530844


100%|██████████| 18/18 [00:02<00:00,  6.84batch/s, train_loss=0.000798]


epoch=75, loss=0.01525854370152128


100%|██████████| 18/18 [00:02<00:00,  7.19batch/s, train_loss=0.000589]


epoch=76, loss=0.015266169667757791


100%|██████████| 18/18 [00:02<00:00,  6.73batch/s, train_loss=0.00133] 


epoch=77, loss=0.015717470547248577


100%|██████████| 18/18 [00:03<00:00,  4.90batch/s, train_loss=0.000715]


epoch=78, loss=0.015464778338013024


100%|██████████| 18/18 [00:03<00:00,  5.68batch/s, train_loss=0.000649]


epoch=79, loss=0.015348030764994951


100%|██████████| 18/18 [00:02<00:00,  6.82batch/s, train_loss=0.000677]


epoch=80, loss=0.015293531322787546


100%|██████████| 18/18 [00:02<00:00,  6.63batch/s, train_loss=0.000939]


epoch=81, loss=0.015347868825341096


100%|██████████| 18/18 [00:02<00:00,  6.97batch/s, train_loss=0.00121] 


epoch=82, loss=0.01546753212398496


100%|██████████| 18/18 [00:02<00:00,  6.94batch/s, train_loss=0.00118] 


epoch=83, loss=0.015436883943861928


100%|██████████| 18/18 [00:02<00:00,  6.60batch/s, train_loss=0.000909]


epoch=84, loss=0.01542608105211422


100%|██████████| 18/18 [00:02<00:00,  7.69batch/s, train_loss=0.000731]


epoch=85, loss=0.015251995435562625


100%|██████████| 18/18 [00:02<00:00,  6.24batch/s, train_loss=0.000871]


epoch=86, loss=0.015480662390589713


100%|██████████| 18/18 [00:02<00:00,  6.78batch/s, train_loss=0.000901]


epoch=87, loss=0.015448431561733115


100%|██████████| 18/18 [00:03<00:00,  5.74batch/s, train_loss=0.000816]


epoch=88, loss=0.015335215225815773


100%|██████████| 18/18 [00:02<00:00,  6.93batch/s, train_loss=0.000461]


epoch=89, loss=0.015132976699492026


100%|██████████| 18/18 [00:02<00:00,  6.52batch/s, train_loss=0.000741]


epoch=90, loss=0.015203204471489479


100%|██████████| 18/18 [00:03<00:00,  5.66batch/s, train_loss=0.0014]  


epoch=91, loss=0.015531938305188868


100%|██████████| 18/18 [00:02<00:00,  7.38batch/s, train_loss=0.00103] 


epoch=92, loss=0.015582470503860508


100%|██████████| 18/18 [00:02<00:00,  7.06batch/s, train_loss=0.000952]


epoch=93, loss=0.015338768607069706


100%|██████████| 18/18 [00:02<00:00,  7.09batch/s, train_loss=0.000633]


epoch=94, loss=0.015322789928522606


100%|██████████| 18/18 [00:02<00:00,  6.59batch/s, train_loss=0.000784]


epoch=95, loss=0.015319250471119226


100%|██████████| 18/18 [00:02<00:00,  7.15batch/s, train_loss=0.000892]


epoch=96, loss=0.015267583617362482


100%|██████████| 18/18 [00:02<00:00,  6.05batch/s, train_loss=0.000924]


epoch=97, loss=0.01562575781602284


100%|██████████| 18/18 [00:03<00:00,  5.61batch/s, train_loss=0.000495]


epoch=98, loss=0.015199870771889027


100%|██████████| 18/18 [00:02<00:00,  6.88batch/s, train_loss=0.000791]


epoch=99, loss=0.015528864988479121


100%|██████████| 18/18 [00:02<00:00,  6.49batch/s, train_loss=0.00076] 

epoch=100, loss=0.015288945635331088





In [16]:
from pandas import DataFrame
from collections import namedtuple

# Re-score every (user, item) pair of the original annotations with both
# models and collect the rounded predictions in a frame with the same layout.
Row = namedtuple("Row", "user_id item_id Valence Arousal")

data_encoder = DataEncoder(original_df=original_df)


def _rounded_prediction(model, users, items):
    """Call `model` on the given id tensors and return the rounded first
    element of its squeezed output as a Python scalar."""
    first_value = model(users=users, items=items).squeeze()[0]
    return torch.round(first_value).item()


dataframe_after_mf = []

with torch.no_grad():
    # Only the two id columns are needed; the original ratings are not read.
    id_pairs = original_df[["user_id", "item_id"]].itertuples(index=False, name=None)
    for user_id, item_id in id_pairs:
        user_id_as_tensor = torch.LongTensor(
            [data_encoder.get_encoded_user_id(original_id=user_id)]
        )
        item_id_as_tensor = torch.LongTensor(
            [data_encoder.get_encoded_item_id(original_id=item_id)]
        )

        dataframe_after_mf.append(
            Row(
                user_id=user_id,
                item_id=item_id,
                Valence=_rounded_prediction(valence_model, user_id_as_tensor, item_id_as_tensor),
                Arousal=_rounded_prediction(arousal_model, user_id_as_tensor, item_id_as_tensor),
            )
        )

df_after_mf = DataFrame(dataframe_after_mf, columns=list(Row._fields))
df_after_mf.head()

Unnamed: 0,user_id,item_id,Valence,Arousal
0,6010bbc8e7ef4b21fa38f9c3a9754ef3,2,5.0,2.0
1,3c888e77b992ae3cd2adfe16774e23b9,2,2.0,3.0
2,2afd218c3aecb6828d2be327f8b9c46f,2,3.0,3.0
3,fd5b08ce362d855ca9152a894348130c,2,4.0,4.0
4,9c8073214a052e414811b76012df8847,2,2.0,2.0


In [19]:
# A row is a "hit" when the rounded predictions reproduce BOTH original ratings.
valence_matches = original_df["Valence"] == df_after_mf["Valence"]
arousal_matches = original_df["Arousal"] == df_after_mf["Arousal"]
mask = valence_matches & arousal_matches

# Despite its name, `changes` holds the matching rows; the predicted values are
# attached next to the originals and are aligned on the row index.
changes = original_df[mask].assign(
    **{
        "New Valence": df_after_mf.Valence,
        "New Arousal": df_after_mf.Arousal,
    }
)
print(f"Number of hits: {len(changes)} / {len(original_df)}")
changes.head(len(changes))

Number of hits: 17341 / 17464


Unnamed: 0,user_id,item_id,Valence,Arousal,New Valence,New Arousal
0,6010bbc8e7ef4b21fa38f9c3a9754ef3,2,5,2,5.0,2.0
1,3c888e77b992ae3cd2adfe16774e23b9,2,2,3,2.0,3.0
2,2afd218c3aecb6828d2be327f8b9c46f,2,3,3,3.0,3.0
3,fd5b08ce362d855ca9152a894348130c,2,4,4,4.0,4.0
4,9c8073214a052e414811b76012df8847,2,2,2,2.0,2.0
...,...,...,...,...,...,...
17459,607f6e34a0b5923333f6b16d3a59cc98,2000,6,5,6.0,5.0
17460,78b5e9744073532cc376976b5fc6b2fc,2000,7,7,7.0,7.0
17461,7cecbffe1da5ae974952db6c13695afe,2000,4,5,4.0,5.0
17462,ed7ed76453bd846859f5e6b9149df276,2000,6,7,6.0,7.0
