## 1. Imports

In [1]:
import pandas as pd
import numpy as np
import yaml
import copy
import torch
import json
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
from sqlalchemy import create_engine
from sklearn.model_selection import train_test_split
from enum import Enum
from pytorch_datasets import SentimentAnalysisDataset, DatasetType

In [2]:
model_name = "ProsusAI/finbert"

## 2. Database configuration & Model config

In [3]:
# Reading form config.yaml"
with open("../../config.yaml", "r") as yamlconfig:
    config = yaml.load(yamlconfig, Loader=yaml.FullLoader)

# Create postgres string with db-config
postgres_username = config["db_config"]["postgres_username"]
postgres_password = config["db_config"]["postgres_password"]
postgres_address = config["db_config"]["postgres_address"]
postgres_port = config["db_config"]["postgres_port"]
postgres_dbname = config["db_config"]["postgres_dbname"]

postgres_str = f"postgresql://{postgres_username}:{postgres_password}@{postgres_address}:{postgres_port}/{postgres_dbname}"

# create db connection with sqlalchemy
cnx = create_engine(postgres_str)

In [4]:
# Load json file with hyperparams of each model
with open('hyperparams.json') as file:
    hyper_params = json.load(file)

In [5]:
# Set up Hyper parameters for model training
LR: float = hyper_params[model_name]["lr"]
OPTIMIZER: str = hyper_params[model_name]["lr"]
EPOCHS: int = hyper_params[model_name]["epochs"]
BATCH_SIZE: int = hyper_params[model_name]["batch_size"]
DROPOUT: float = hyper_params[model_name]["dropout"]

## 3. Dataframe preperations

In [6]:
df = pd.read_sql('SELECT * FROM r_wallstreetbets_stock_symbols LIMIT 100', cnx)

In [7]:
df['label'] = np.random.choice([1,2,3], df.shape[0])

In [8]:
df = df[["post", "label"]]

In [9]:
df["label"].value_counts(normalize=True)*100

2    41.0
1    33.0
3    26.0
Name: label, dtype: float64

## 4. Model Loading

In [10]:
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [11]:
#model

## 3. Building Pytorch Dataset

In [12]:
# Declare generic sentiment analysis dataset without split
sentiment_analysis_dataset = SentimentAnalysisDataset(
    df = df,
    tokenizer = tokenizer
)

In [13]:
# Declare train and test dataset
train_dataset = copy.deepcopy(sentiment_analysis_dataset).set_fold(DatasetType.TRAIN)
test_dataset = copy.deepcopy(sentiment_analysis_dataset).set_fold(DatasetType.TEST)

In [21]:
# Setup train and test Data loaders
train_data_loader = torch.utils.data.DataLoader(train_dataset,
                                                batch_size=BATCH_SIZE,
                                                shuffle=True,
                                                num_workers=1,
                                                drop_last=False # maybe change in future
                                                )

test_data_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=test_dataset.__len__(),
                                              shuffle=False,
                                              num_workers=1,
                                              )

In [15]:
train_dataset.__len__()

80

In [16]:
test_dataset.__len__()

20

In [25]:
# Check if train data and test data have correct batch and tensor sizes
"""print('TRAINING DATA:')
for dictionary in train_data_loader:
    print(dictionary)
    break"""

print(' ')
print('TESTING DATA:')
for dictionary in test_data_loader:
    print(dictionary["labels"].size())
    #break

 
TESTING DATA:
torch.Size([20, 1])


## WANDB test

In [26]:
import wandb

wandb.init(project="test-project", entity="hda_sis")

[34m[1mwandb[0m: Currently logged in as: [33mjan_burger[0m ([33mhda_sis[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [27]:
wandb.config = {
    "learning_rate": 0.001,
    "epochs": 100,
    "batch_size": 128
}

In [29]:
wandb.log({"loss": 2.5})