<a href="https://colab.research.google.com/github/arampacha/fastai_nbs/blob/main/deberta_finetuning_fastai_imdb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# NOTE(review): versions are unpinned (-U installs the latest releases). The cells below use
# blurr names such as HF_TASKS_AUTO, BLURR_MODEL_HELPER and HF_TextBlock, which may not exist in
# newer blurr releases — pin versions known to work to keep this notebook reproducible.
!pip install -Uqq fastai transformers ohmeow-blurr wandb

[K     |████████████████████████████████| 194kB 9.1MB/s 
[K     |████████████████████████████████| 1.8MB 15.6MB/s 
[K     |████████████████████████████████| 61kB 9.5MB/s 
[K     |████████████████████████████████| 2.0MB 50.3MB/s 
[K     |████████████████████████████████| 61kB 9.2MB/s 
[K     |████████████████████████████████| 3.2MB 49.9MB/s 
[K     |████████████████████████████████| 890kB 52.1MB/s 
[K     |████████████████████████████████| 51kB 6.6MB/s 
[K     |████████████████████████████████| 4.1MB 48.9MB/s 
[K     |████████████████████████████████| 184kB 58.0MB/s 
[K     |████████████████████████████████| 14.5MB 226kB/s 
[K     |████████████████████████████████| 1.2MB 51.1MB/s 
[K     |████████████████████████████████| 163kB 57.6MB/s 
[K     |████████████████████████████████| 102kB 15.4MB/s 
[K     |████████████████████████████████| 133kB 56.4MB/s 
[K     |████████████████████████████████| 102kB 14.1MB/s 
[K     |████████████████████████████████| 20.7MB 1.4MB/s 
[K 

In [None]:
from fastai.text.all import *
from fastai.callback.wandb import *

In [None]:
def read_text(fn, encoding='utf-8'):
    """Return the full contents of the text file `fn`.

    Used as the DataBlock `get_x`, so it runs once per review file.

    Args:
        fn: path to a text file (str or Path).
        encoding: text encoding used to decode the file. Explicit so decoding does not
            depend on the machine's locale default.

    Returns:
        The file contents as a single string.
    """
    # `with` closes the handle deterministically; the bare `open(fn).read()` left closing
    # to the garbage collector, which can exhaust file descriptors when many files are read.
    with open(fn, encoding=encoding) as f:
        return f.read()

In [None]:
# Fetch and extract the IMDB reviews archive; `path` is the extracted dataset directory.
path = untar_data(url=URLs.IMDB)

In [None]:
# Pretrained checkpoint and batching configuration.
model_name = "microsoft/deberta-base"

max_len = 512        # target sequence length in tokens
bs, val_bs = 8, 16   # training / validation batch sizes

## Tracking

In [None]:
# Interactive W&B authentication: prompts for an API key and appends it to ~/.netrc (see output).
!wandb login

[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter: 
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
import wandb

WANDB_NAME = f'imdb-{model_name}-simple'
GROUP = 'IMDB'
NOTES = f'Simple finetuning {model_name} with RAdam lr=2e-5'
# Record the run's hyper-parameters in W&B so the run can be reproduced from its config
# (the original passed an empty dict). Keep these in sync with the config cell and the
# arguments of the "simple fine-tuning" `fit_one_cycle` call.
CONFIG = dict(model_name=model_name, max_len=max_len, bs=bs, val_bs=val_bs,
              opt_func='RAdam', lr=2e-5, n_epoch=4)
TAGS = [model_name, 'imdb', 'radam', 'simple']

In [None]:
# Start (or restart, via reinit) the W&B run that the training callback logs into.
# NOTE(review): `entity` is hard-coded to the original author's account — change it to your own.
wandb.init(
    project="lm-finetuning",
    entity="arampacha",
    name=WANDB_NAME,
    group=GROUP,
    notes=NOTES,
    tags=TAGS,
    config=CONFIG,
    reinit=True,
)

[34m[1mwandb[0m: Currently logged in as: [33mfastai_community[0m (use `wandb login --relogin` to force relogin)


## Training

In [None]:
from transformers import *

from blurr.data.all import *
from blurr.modeling.all import *

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Unzipping corpora/wordnet.zip.


In [None]:
task = HF_TASKS_AUTO.SequenceClassification

# Use the `max_len` constant from the config cell rather than a second hard-coded 512, and
# pass it as `model_max_length` — the tokenizer kwarg transformers documents (`max_len`
# is the deprecated alias).
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(
    model_name, task=task, tokenizer_kwargs={'model_max_length': max_len})

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=474.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=3917897.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=558582766.0, style=ProgressStyle(descri…




In [None]:
# Text input tokenized by blurr's HF_TextBlock; the label is the parent folder name.
# Only the labeled `train`/`test` folders are scanned: GrandparentSplitter keeps just those two
# splits, so listing any other folders under the dataset root (e.g. unlabeled reviews) is wasted work.
blocks = (HF_TextBlock(hf_arch, hf_config, hf_tokenizer, hf_model), CategoryBlock)
dblock = DataBlock(blocks=blocks,
                   get_items=partial(get_text_files, folders=['train', 'test']),
                   get_x=read_text,
                   get_y=parent_label,
                   splitter=GrandparentSplitter(valid_name='test'))

dls = dblock.dataloaders(path, bs=bs, val_bs=val_bs)

In [None]:
# Wrap the HF model for fastai; train in mixed precision with RAdam.
model = HF_BaseModelWrapper(hf_model)
learn = Learner(dls,
                model,
                opt_func=RAdam,
                metrics=[accuracy],
                cbs=[HF_BaseModelCallback],
                splitter=hf_splitter).to_fp16()

# The optimizer must exist for freeze() to apply to its parameter groups.
learn.create_opt()
learn.freeze()

# Snapshot the untrained starting point. The later experiments call `learn.load('stage0')`
# to restart from identical weights, and nothing else in this notebook writes that
# checkpoint, so a fresh-kernel run would otherwise fail there.
# NOTE(review): if `stage0` was meant to be a later checkpoint (e.g. after head-only
# training), move this save to that point instead.
learn.save('stage0')

learn.blurr_summary()

HF_BaseModelWrapper (Input shape: 8 x 512)
Layer (type)         Output Shape         Param #    Trainable 
                     8 x 512 x 768       
Embedding                                 38603520   False     
DebertaLayerNorm                          1536       False     
StableDropout                                                  
____________________________________________________________________________
                     8 x 512 x 2304      
Linear                                    1769472    False     
StableDropout                                                  
Linear                                    589824     False     
Linear                                    590592     False     
StableDropout                                                  
Linear                                    590592     False     
DebertaLayerNorm                          1536       False     
StableDropout                                                  
____________________________

### Simple fine-tuning (all layers unfrozen, single learning rate)

In [None]:
# Train every layer at a single learning rate; the W&B callback logs metrics but uploads
# neither predictions nor model weights.
learn.unfreeze()
wandb_cb = WandbCallback(log_preds=False, log_model=False)
learn.fit_one_cycle(n_epoch=4, lr_max=2e-5, cbs=wandb_cb)

Could not gather input dimensions


epoch,train_loss,valid_loss,accuracy,time
0,0.160276,0.164386,0.93976,29:55
1,0.132156,0.129645,0.95312,30:00
2,0.048035,0.143527,0.957,30:00
3,0.04481,0.180624,0.958,29:58


### All unfrozen, discriminative learning rates (restarting from `stage0`)

In [None]:
# Restart from the saved checkpoint and train all layer groups with discriminative
# learning rates: 2e-6 for the first group up to 2e-5 for the last.
learn.load('stage0')
learn.unfreeze()
learn.fit_one_cycle(n_epoch=3, lr_max=slice(2e-6, 2e-5))

epoch,train_loss,valid_loss,accuracy,time
0,0.18403,0.191095,0.92588,16:47
1,0.179359,0.179827,0.93104,16:47
2,0.12062,0.194259,0.931,16:46


### Adam instead of RAdam (restarting from `stage0`)

In [None]:
# Restart from the saved checkpoint and switch the optimizer from RAdam to Adam.
learn.load('stage0')
learn.opt_func = Adam
# `learn.opt` already exists (built by create_opt/load with RAdam), and `fit` only builds a
# new optimizer when `learn.opt` is None or reset_opt=True. Without rebuilding it here,
# the following run would silently keep using RAdam.
learn.create_opt()

In [None]:
# Train all layer groups at a single learning rate with whatever optimizer `learn.opt` holds.
learn.unfreeze()
learn.fit_one_cycle(n_epoch=3, lr_max=2e-5)

epoch,train_loss,valid_loss,accuracy,time
