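# Train generative reflection model

Train a `ReflectionModel`, which combines a frozen pretrained RoBERTa encoder with an LSTM decoder, on the CounselChat answer data. Then save its weights to `models/reflections.pt`.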

In [1]:
from pathlib import Path

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from transformers import RobertaModel, RobertaTokenizer

from youper import get_data, toggle_model_freeze, print_frozenness, ReflectionModel

In [2]:
%load_ext autoreload
%load_ext tensorboard
%autoreload 2

## Load data

In [3]:
# Load the full dataset (all splits live in one frame) and show its
# (rows, columns) shape as a quick sanity check.
df = get_data()
n_rows, n_cols = df.shape
n_rows, n_cols

(2271, 14)

In [4]:
# Peek at the first five rows.
# NOTE(review): the `Unnamed: 0` column looks like a pandas index written to
# and re-read from CSV — consider dropping it in get_data (confirm).
df.iloc[:5]

Unnamed: 0,questionID,questionTitle,questionText,questionLink,topic,therapistInfo,therapistURL,answerText,upvotes,views,split,first_an_sent,seems_sounds_sents,reflection
0,0,Can I change my feeling of being worthless to ...,I'm going through some things with my feelings...,https://counselchat.com/questions/can-i-change...,depression,"Sherry Katz, LCSWCouples and Family Therapist,...",https://counselchat.com/therapists/sherry-katz...,"If everyone thinks you're worthless, then mayb...",1,2899,train,"If everyone thinks you're worthless, then mayb...",,"If everyone thinks you're worthless, then mayb..."
1,0,Can I change my feeling of being worthless to ...,I'm going through some things with my feelings...,https://counselchat.com/questions/can-i-change...,depression,"Robin Landwehr, DBH, LPCC, NCCMental Health in...",https://counselchat.com/therapists/robin-landw...,"Hello, and thank you for your question and see...",1,3514,train,"Hello, and thank you for your question and see...",,"Hello, and thank you for your question and see..."
2,0,Can I change my feeling of being worthless to ...,I'm going through some things with my feelings...,https://counselchat.com/questions/can-i-change...,depression,Lee KingI use an integrative approach to treat...,https://counselchat.com/therapists/lee-king,First thing I'd suggest is getting the sleep y...,0,5,train,First thing I'd suggest is getting the sleep y...,,First thing I'd suggest is getting the sleep y...
3,0,Can I change my feeling of being worthless to ...,I'm going through some things with my feelings...,https://counselchat.com/questions/can-i-change...,depression,"Shauntai Davis-YearginPersonalized, private on...",https://counselchat.com/therapists/shauntai-da...,Therapy is essential for those that are feelin...,0,31,train,Therapy is essential for those that are feelin...,,Therapy is essential for those that are feelin...
4,0,Can I change my feeling of being worthless to ...,I'm going through some things with my feelings...,https://counselchat.com/questions/can-i-change...,depression,Jordan WhiteLicensed Social Worker at Oak Root...,https://counselchat.com/therapists/jordan-white,I first want to let you know that you are not ...,0,620,train,I first want to let you know that you are not ...,,I first want to let you know that you are not ...


## Train / Valid / Test Split

In [5]:
# Boolean row masks built from the precomputed `split` column. Only the
# train/valid masks are passed to the model later; test stays held out.
train_mask = df.split == 'train'
valid_mask = df.split == 'val'
test_mask = df.split == 'test'

# Every row must carry one of the three known labels. Otherwise it would
# silently fall outside all masks (e.g. a typo or a missing value in `split`).
expected_splits = {'train', 'val', 'test'}
unexpected_splits = set(df.split.unique()) - expected_splits
assert not unexpected_splits, f"unexpected split labels: {unexpected_splits}"

train_mask.sum(), valid_mask.sum(), test_mask.sum()

(1963, 185, 123)

## Pretrained RoBERTa encoder

In [6]:
# Pretrained checkpoint shared by the encoder and its tokenizer.
PRETRAINED_NAME = 'roberta-base'

roberta = RobertaModel.from_pretrained(PRETRAINED_NAME)
tokenizer = RobertaTokenizer.from_pretrained(PRETRAINED_NAME)

# Freeze the encoder so training does not update RoBERTa, then report the
# frozen-parameter fraction (expected: all of them).
toggle_model_freeze(roberta, frozen=True)
print_frozenness(roberta)

199 / 199 (100.000)% parameters are frozen


## Reflection model

In [7]:
# Extra state handed to ReflectionModel alongside the encoder. No test mask is
# passed. Presumably the model uses these to build its train/val dataloaders —
# confirm in youper.
DEC_HIDDEN_SZ = 512  # passed as ReflectionModel's dec_hidden_sz

addl_state = dict(
    df=df,
    train_mask=train_mask,
    valid_mask=valid_mask,
    tokenizer=tokenizer,
)

model = ReflectionModel(roberta=roberta, dec_hidden_sz=DEC_HIDDEN_SZ, addl_state=addl_state)

## Train

In [8]:
# Stop training once val_loss (lower is better) has not improved by more than
# min_delta for `patience` validation checks.
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    min_delta=0.00,
    patience=3,
    verbose=False,
    mode='min'
)

# Use one GPU when available and fall back to CPU otherwise, rather than
# erroring out on a CPU-only machine.
# NOTE(review): `early_stop_callback=` and `gpus=` belong to the older
# PyTorch Lightning Trainer API this notebook was run with. Newer releases
# expect `callbacks=[early_stop_callback]` and `accelerator=`/`devices=` instead.
trainer = Trainer(
    early_stop_callback=early_stop_callback,
    gpus=1 if torch.cuda.is_available() else None,
)

In [9]:
# Show TensorBoard inline for the training logs. An already-running server with
# matching arguments may be reused instead of starting a new one.
# `lightning_logs/` is presumably the Trainer's default log directory — confirm.
%tensorboard --logdir lightning_logs/

Reusing TensorBoard on port 6007 (pid 11428), started 4:56:24 ago. (Use '!kill 11428' to kill it.)

In [10]:
# Train the reflection model. Early stopping is configured on the trainer, and
# fit's return value is shown as the cell output.
trainer.fit(model=model)

HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=24.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=24.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=24.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=24.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=24.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=24.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=24.0, style=Pro…




1

In [11]:
# Persist the model's state_dict (weights only). Reloading requires rebuilding
# ReflectionModel with the same constructor arguments first.
# NOTE(review): with early stopping (patience=3) the in-memory weights are from
# the final epoch, possibly several epochs past the best val_loss. Consider
# restoring the best Lightning checkpoint before saving.
weights_path = Path('models') / 'reflections.pt'
weights_path.parent.mkdir(parents=True, exist_ok=True)  # torch.save does not create missing dirs
torch.save(model.state_dict(), str(weights_path))

In [12]:
# Total number of parameters (trainable or not) in the decoder's LSTM.
decoder_lstm = model.dec.lstm
n_lstm_params = sum(map(torch.numel, decoder_lstm.parameters()))
n_lstm_params

2625536