<a href="https://colab.research.google.com/github/ElFosco/NLP_argument_creation/blob/main/XLNet_STS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup
Install dependencies

In [1]:
! pip install sentencepiece

Collecting sentencepiece
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[?25l[K     |▎                               | 10 kB 19.9 MB/s eta 0:00:01[K     |▌                               | 20 kB 9.7 MB/s eta 0:00:01[K     |▉                               | 30 kB 8.1 MB/s eta 0:00:01[K     |█                               | 40 kB 7.5 MB/s eta 0:00:01[K     |█▍                              | 51 kB 5.2 MB/s eta 0:00:01[K     |█▋                              | 61 kB 5.3 MB/s eta 0:00:01[K     |██                              | 71 kB 5.5 MB/s eta 0:00:01[K     |██▏                             | 81 kB 6.2 MB/s eta 0:00:01[K     |██▍                             | 92 kB 4.9 MB/s eta 0:00:01[K     |██▊                             | 102 kB 5.4 MB/s eta 0:00:01[K     |███                             | 112 kB 5.4 MB/s eta 0:00:01[K     |███▎                            | 122 kB 5.4 MB/s eta 0:00:01[K     |███▌         

Download the pretrained XLNet model and unzip

In [2]:
# only needs to be done once
! wget https://storage.googleapis.com/xlnet/released_models/cased_L-24_H-1024_A-16.zip
! unzip cased_L-24_H-1024_A-16.zip 

--2022-01-31 09:37:30--  https://storage.googleapis.com/xlnet/released_models/cased_L-24_H-1024_A-16.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 142.251.6.128, 142.250.159.128, 74.125.201.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.251.6.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1338042341 (1.2G) [application/zip]
Saving to: ‘cased_L-24_H-1024_A-16.zip’


2022-01-31 09:37:49 (68.2 MB/s) - ‘cased_L-24_H-1024_A-16.zip’ saved [1338042341/1338042341]

Archive:  cased_L-24_H-1024_A-16.zip
   creating: xlnet_cased_L-24_H-1024_A-16/
  inflating: xlnet_cased_L-24_H-1024_A-16/xlnet_model.ckpt.index  
  inflating: xlnet_cased_L-24_H-1024_A-16/xlnet_model.ckpt.data-00000-of-00001  
  inflating: xlnet_cased_L-24_H-1024_A-16/spiece.model  
  inflating: xlnet_cased_L-24_H-1024_A-16/xlnet_model.ckpt.meta  
  inflating: xlnet_cased_L-24_H-1024_A-16/xlnet_config.json  


Download extract the sts-b dataset 

In [3]:
! wget https://dl.fbaipublicfiles.com/glue/data/STS-B.zip
! unzip STS-B.zip

--2022-01-31 09:38:05--  https://dl.fbaipublicfiles.com/glue/data/STS-B.zip
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 172.67.9.4, 104.22.75.142, 104.22.74.142, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|172.67.9.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 802872 (784K) [application/zip]
Saving to: ‘STS-B.zip’


2022-01-31 09:38:06 (2.53 MB/s) - ‘STS-B.zip’ saved [802872/802872]

Archive:  STS-B.zip
   creating: STS-B/
  inflating: STS-B/LICENSE.txt       
  inflating: STS-B/dev.tsv           
   creating: STS-B/original/
  inflating: STS-B/original/sts-dev.tsv  
  inflating: STS-B/original/sts-test.tsv  
  inflating: STS-B/original/sts-train.tsv  
  inflating: STS-B/readme.txt        
  inflating: STS-B/test.tsv          
  inflating: STS-B/train.tsv         


Git clone XLNet repo for access to run_classifier and the rest of the xlnet module

In [4]:
! git clone https://github.com/zihangdai/xlnet.git

Cloning into 'xlnet'...
remote: Enumerating objects: 122, done.[K
remote: Total 122 (delta 0), reused 0 (delta 0), pack-reused 122[K
Receiving objects: 100% (122/122), 2.92 MiB | 13.17 MiB/s, done.
Resolving deltas: 100% (59/59), done.


## Define Variables
Define all the dirs: data, xlnet scripts & pretrained model. 
If you would like to save models then you can authenticate a GCP account and use that for the OUTPUT_DIR & CHECKPOINT_DIR - you will need a large amount storage to fix these models. 

Alternatively it is easy to integrate a google drive account, checkout this guide for [I/O in colab](https://colab.research.google.com/notebooks/io.ipynb) but rememeber these will take up a large amount of storage. 


In [5]:
SCRIPTS_DIR = 'xlnet' #@param {type:"string"}
DATA_DIR = 'STS-B' #@param {type:"string"}
OUTPUT_DIR = 'proc_data/STS-B' #@param {type:"string"}
PRETRAINED_MODEL_DIR = 'xlnet_cased_L-24_H-1024_A-16' #@param {type:"string"}
CHECKPOINT_DIR = 'exp/STS-B' #@param {type:"string"}

## Run Model
This will set off the fine tuning of XLNet. There are a few things to note here:


1.   This script will train and evaluate the model
2.   This will store the results locally on colab and will be lost when you are disconnected from the runtime
3.   This uses the large version of the model (base not released presently)
4.   We are using a max seq length of 128 with a batch size of 8 please refer to the [README](https://github.com/zihangdai/xlnet#memory-issue-during-finetuning) for why this is.
5. This will take approx 4hrs to run on GPU.



In [7]:
# to fix the dependecies issues
!pip install tensorflow-gpu==1.15.0

Collecting tensorflow-gpu==1.15.0
  Downloading tensorflow_gpu-1.15.0-cp37-cp37m-manylinux2010_x86_64.whl (411.5 MB)
[K     |████████████████████████████████| 411.5 MB 7.5 kB/s 
Collecting gast==0.2.2
  Downloading gast-0.2.2.tar.gz (10 kB)
Collecting keras-applications>=1.0.8
  Downloading Keras_Applications-1.0.8-py3-none-any.whl (50 kB)
[K     |████████████████████████████████| 50 kB 6.0 MB/s 
Collecting tensorboard<1.16.0,>=1.15.0
  Downloading tensorboard-1.15.0-py3-none-any.whl (3.8 MB)
[K     |████████████████████████████████| 3.8 MB 33.1 MB/s 
Collecting tensorflow-estimator==1.15.1
  Downloading tensorflow_estimator-1.15.1-py2.py3-none-any.whl (503 kB)
[K     |████████████████████████████████| 503 kB 41.8 MB/s 
Building wheels for collected packages: gast
  Building wheel for gast (setup.py) ... [?25l[?25hdone
  Created wheel for gast: filename=gast-0.2.2-py3-none-any.whl size=7554 sha256=c07e6bd55223cc5fe55b7c15d65339779b573cef00783dddb9aa82aebcce7ae5
  Stored in direct

In [8]:

train_command = "python xlnet/run_classifier.py \
  --do_train=True \
  --do_eval=True \
  --eval_all_ckpt=False \
  --task_name=sts-b \
  --data_dir="+DATA_DIR+" \
  --output_dir="+OUTPUT_DIR+" \
  --model_dir="+CHECKPOINT_DIR+" \
  --uncased=False \
  --spiece_model_file="+PRETRAINED_MODEL_DIR+"/spiece.model \
  --model_config_path="+PRETRAINED_MODEL_DIR+"/xlnet_config.json \
  --init_checkpoint="+PRETRAINED_MODEL_DIR+"/xlnet_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=8 \
  --eval_batch_size=8 \
  --num_hosts=1 \
  --num_core_per_host=1 \
  --learning_rate=2e-5 \
  --train_steps=4000 \
  --warmup_steps=500 \
  --save_steps=500 \
  --iterations=500"

! {train_command}





W0131 09:40:38.829209 139778454554496 module_wrapper.py:139] From xlnet/run_classifier.py:637: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.


W0131 09:40:38.829510 139778454554496 module_wrapper.py:139] From xlnet/run_classifier.py:637: The name tf.logging.INFO is deprecated. Please use tf.compat.v1.logging.INFO instead.


W0131 09:40:38.829855 139778454554496 module_wrapper.py:139] From xlnet/run_classifier.py:661: The name tf.gfile.Exists is deprecated. Please use tf.io.gfile.exists instead.


W0131 09:40:38.830362 139778454554496 module_wrapper.py:139] From xlnet/run_classifier.py:662: The name tf.gfile.MakeDirs is deprecated. Please use tf.io.gfile.makedirs instead.


W0131 09:40:38.904997 139778454554496 module_wrapper.py:139] From /content/xlnet/model_utils.py:27: The name tf.ConfigProto is deprecated. Please use tf.compat.v1.ConfigProto instead.


W0131 09:40:38.905514 139778454554496 module_wrapper.py:139] From /cont

## Running & Results
These are the results that I got from running this experiment
### Params
*    --max_seq_length=128 \
*    --train_batch_size= 8 

### Times
*   Training: 1hr 11mins
*   Evaluation: 2.5hr

### Results
*  Most accurate model on final step
*  Accuracy: 0.92416, eval_loss: 0.31708


### Model

*   The trained model checkpoints can be found in 'exp/imdb'

