# Preparation

## Connect to Drive

In [None]:
# Set to True when running on Google Colab: the cells below then mount Google
# Drive, install the dependencies and point project_folder at the shared drive.
connect_to_drive: bool = False

In [None]:
# Colab only: mount Google Drive under /content/gdrive.
# Running this opens an authorization popup in another window.
if connect_to_drive:
    from google.colab import drive

    drive.mount(mountpoint='/content/gdrive', force_remount=True)

## Install packages

In [None]:
if connect_to_drive:
    #Install FS code
    !pip install  --upgrade --force-reinstall git+https://github.com/federicosiciliano/easy_lightning.git

    !pip install pytorch_lightning

## Imports

In [None]:
#Put all imports here
import numpy as np
from copy import deepcopy
import os
import sys

## Define paths

In [None]:
# Every path below is built relative to the project folder.
# Locally that is the parent of the working directory; on Colab it is the mounted
# Shared Drive folder (for a personal drive use "/content/gdrive/MyDrive/<MyDriveName>").
project_folder = (
    "/content/gdrive/Shareddrives/<SharedDriveName>"  # name of the Shared Drive folder
    if connect_to_drive
    else "../"
)

# Hyperparameter configurations
cfg_folder = os.path.join(project_folder, "cfg")

# Raw and preprocessed data
data_folder = os.path.join(project_folder, "data")
raw_data_folder, processed_data_folder = (
    os.path.join(data_folder, sub) for sub in ("raw", "processed")
)

# Essential source code
source_folder = os.path.join(project_folder, "src")

# All outputs: models, results, plots, etc.
out_folder = os.path.join(project_folder, "out")
img_folder = os.path.join(out_folder, "img")

## Import project code

In [None]:
# To import from src:

# Put the project folder (the parent of the `src` directory) first on sys.path so
# that `from src import ...` resolves. The guard keeps re-runs of this cell from
# prepending duplicate entries.
if project_folder not in sys.path:
    sys.path.insert(0, project_folder)

# import from src directory
# from src import ??? as additional_module
from easy_lightning import easy_rec as additional_module #REMOVE THIS LINE IF IMPORTING OWN ADDITIONAL MODULE

from easy_lightning import easy_exp, easy_rec, easy_torch #easy_data

In [None]:
from datasets import load_dataset

# Download the "shba93/tim-rec" train split from the Hugging Face Hub once and
# cache it as <raw_data_folder>/tim/dataset.csv.
tim_raw_folder = os.path.join(raw_data_folder, "tim")
tim_csv_path = os.path.join(tim_raw_folder, "dataset.csv")

# Check for the CSV itself, not just its folder: a run that created the folder but
# failed before finishing the write would otherwise skip the download forever.
if not os.path.exists(tim_csv_path):
    dataset = load_dataset("shba93/tim-rec", split="train")
    os.makedirs(tim_raw_folder, exist_ok=True)
    # Write to a temporary file and move it into place, so an interrupted write
    # never leaves a truncated dataset.csv that looks complete.
    tmp_csv_path = tim_csv_path + ".part"
    dataset.to_csv(tmp_csv_path, index=False)
    os.replace(tmp_csv_path, tim_csv_path)
    # Alternative without the CSV cache:
    # dataset = load_dataset("shba93/tim-rec")["train"].to_pandas()

# Main

## Train

### Data

In [None]:
# Load the "config_rec" experiment configuration through easy_exp.
cfg = easy_exp.cfg.load_configuration(
    "config_rec"
)

In [None]:
# Obtain the experiment id for this cfg; exp_found is the accompanying flag
# (presumably True when an experiment with this configuration already exists).
(exp_found, experiment_id) = easy_exp.exp.get_set_experiment_id(
    cfg
)

In [None]:
# Prepare the recommendation data and the accompanying maps from cfg.
(data, maps) = easy_rec.preparation.prepare_rec_data(
    cfg
)

In [None]:
# Rescale every rating y to (y - 0.5) * 2, i.e. 0 -> -1.0, 0.5 -> 0.0, 1 -> 1.0,
# for each split's list of rating sequences.
# NOTE(review): presumably ratings are in [0, 1]. The entries of `data` are
# replaced in place, so re-running this cell rescales them a second time.
for split_name in ("train", "val", "test"):
    data[f"{split_name}_rating"] = [
        [(rating - 0.5) * 2 for rating in ratings]
        for ratings in data[f"{split_name}_rating"]
    ]

In [None]:
# Build the dataloaders from cfg, the (rescaled) data and the maps.
loaders = easy_rec.preparation.prepare_rec_dataloaders(
    cfg,
    data,
    maps,
)

In [None]:
# Build the recommendation model (the main module) from cfg and the maps.
main_module = easy_rec.preparation.prepare_rec_model(
    cfg,
    maps,
)

In [None]:
# Build the trainer for this experiment id, with easy_rec as the additional module.
# NOTE(review): the imports cell aliases easy_rec as `additional_module`; if that
# alias is switched to your own module, consider passing `additional_module` here.
trainer = easy_torch.preparation.complete_prepare_trainer(
    cfg,
    experiment_id,
    additional_module=easy_rec,
)

In [None]:
# Wrap main_module into the final model; easy_rec is passed as the third argument
# (presumably the additional module, as for the trainer).
model = easy_torch.preparation.complete_prepare_model(
    cfg,
    main_module,
    easy_rec,
)

In [None]:
# Show the experiment id: as the cell's last expression, the notebook displays it.
experiment_id

In [None]:
# Evaluate the not-yet-trained model on the val, test and train splits
# (a reference point to compare against the post-training evaluation).
easy_torch.process.test_model(
    trainer,
    model,
    loaders,
    test_key=["val", "test", "train"],
)

In [None]:
# Fit the model with the prepared trainer and loaders, validating on "val" and "test".
# NOTE(review): if val_key also drives checkpoint selection or early stopping,
# including "test" leaks test-set information into model selection — confirm.
easy_torch.process.train_model(
    trainer,
    model,
    loaders,
    val_key=["val", "test"],
)

# Note: with Tune schedulers, early stopping may end the run right after training, so the cells below might not execute

In [None]:
# Evaluate the trained model on the same val, test and train splits as before training.
easy_torch.process.test_model(
    trainer,
    model,
    loaders,
    test_key=["val", "test", "train"],
)

In [None]:
# Save the experiment described by cfg via easy_exp.
# NOTE(review): saving is meant to be robust to Tune early stopping, which may end
# the run after training; this cell runs *after* training and testing, so in that
# case it would never execute. Consider moving the save before training.
# NOTE(review): `exp_found` from get_set_experiment_id is currently unused; a
# "skip if the experiment already exists" check would need a guard other than a
# top-level `return`, which is not valid outside a function.
easy_exp.exp.save_experiment(cfg)