Clone the repo:
```bash
git clone https://github.com/IBM/DP-TabTransformer.git
cd DP-TabTransformer
```
Environment configuration:
```bash
conda create -n DP-TabTransformer python=3.9
conda activate DP-TabTransformer
pip install -r requirements.txt
```
You are all set! 🎉
We provide a simple training script, run_train_from_scratch.py. To train TabTransformer from scratch, simply run:
```bash
python run_train_from_scratch.py
```
We use the ACSIncome_IN dataset by default. To use your own data, put it in the data folder and update the data path in train.py accordingly.
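For illustration, pointing train.py at your own file might look like the snippet below. The variable name `DATA_PATH` is a hypothetical stand-in; check train.py for the actual data-loading code.

```python
# Hypothetical stand-in for the data path in train.py; the real variable
# name may differ. Default dataset shipped with the repo:
DATA_PATH = "data/ACSIncome_IN.csv"

# To train on your own data, place the file under data/ and point to it:
# DATA_PATH = "data/my_dataset.csv"
```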
We offer a straightforward script, run_dp_sgd.py, which pre-trains and fine-tunes the TabTransformer with DP-SGD: pre-training on the ACSIncome_CA dataset and fine-tuning on the ACSIncome_IN dataset. To experiment with it, simply execute the following command:
```bash
python run_dp_sgd.py
```
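At its core, DP-SGD clips each example's gradient to a fixed norm and adds Gaussian noise before the optimizer step. The sketch below illustrates that aggregation step in NumPy; it is a didactic illustration with made-up hyperparameters, not the repo's implementation (DP-SGD training is commonly delegated to a library such as Opacus).

```python
import numpy as np

def dp_sgd_step(per_example_grads, clip_norm=1.0, noise_multiplier=1.1, rng=None):
    """Aggregate per-example gradients the DP-SGD way: clip each example's
    gradient to clip_norm, sum, add Gaussian noise with standard deviation
    noise_multiplier * clip_norm, then average over the batch."""
    rng = rng or np.random.default_rng(0)
    clipped = []
    for g in per_example_grads:
        norm = np.linalg.norm(g)
        # Scale down only if the gradient exceeds the clipping norm.
        clipped.append(g * min(1.0, clip_norm / max(norm, 1e-12)))
    total = np.sum(clipped, axis=0)
    noise = rng.normal(0.0, noise_multiplier * clip_norm, size=total.shape)
    return (total + noise) / len(per_example_grads)
```

With `noise_multiplier=0.0` the step reduces to plain gradient clipping, which makes the clipping behavior easy to check in isolation.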
By default, this script employs Shallow Tuning. To utilize other methods, you can enable them as follows:
- For LoRA, set `use_lora=True` in `run_dp_sgd.py`.
- To use Adapter, set `use_adapter=True`.
- For Deep Tuning, activate `deep_tuning=True`.
- To enable full tuning, set `full_tuning=True` accordingly.
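For instance, switching from the default Shallow Tuning to LoRA might look like the flag block below. This is a hypothetical sketch of the settings in run_dp_sgd.py; that only one mode should be active at a time is an assumption, so check the script's logic.

```python
# Hypothetical tuning-mode flags as they might appear in run_dp_sgd.py.
# Assumption: enable exactly one mode at a time.
use_lora    = True    # switch from the default Shallow Tuning to LoRA
use_adapter = False
deep_tuning = False
full_tuning = False
```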