# Install dependencies 

In [1]:
!pip install -q transformers==4.3.2

!pip install -q torch

!pip install -q numpy
!pip install -q pandas

# Import libraries

In [2]:
import sys
import numpy as np
import pandas as pd
import os

# The cells below import GanBert and read/write data/ and results/ relative to
# the parent of this notebook's directory. Remember the starting directory the
# first time this cell runs, so re-running it does not keep climbing the tree
# (a bare os.chdir('..') moves up one more level on every re-run).
if "_NOTEBOOK_DIR" not in globals():
    _NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_DIR, os.pardir))
from GanBert import GanBert



# Set-up

Set the global variable that selects the dataset.

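Set `data_name` to the name of your dataset. It must match a folder under `data/` that was generated by the `generate_data.ipynb` notebook. The path is resolved relative to the parent of this notebook's directory, because the import cell changes into `..`. `num_classes` must be set manually to the number of classes in your dataset. Note that this notebook does not define a `num_classes` variable, so confirm whether this step is still required.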

- data_name:
    - "imdb"
    - "medical"


In [3]:
# Dataset selector read by the cells below.
# It must match a folder under data/ produced by generate_data.ipynb;
# known values: "imdb", "medical".
data_name = "medical"

## Get data
- Collect the paths of the labeled training files (5, 10, 25, and 50 examples per label).
- Load the unlabeled training data (a single file, `train_unlabeled.csv`).
- Load the test set (`test.csv`).

In [4]:
## convert a dataframe into the format expected by the GanBert model
def data_to_ganbert(data):
    """Return the first two columns of ``data`` as a list of row-wise tuples.

    Any columns after the second are ignored. The first column is presumably
    the text and the second the label -- confirm against generate_data.ipynb.
    """
    first_col, second_col = data.iloc[:, 0], data.iloc[:, 1]
    return [pair for pair in zip(first_col, second_col)]

In [5]:
## collect the paths of the labeled training files (5, 10, 25, and 50 examples per label)
data_path = 'data/' + data_name
files = os.listdir(data_path)
# Entries have the form "data/<data_name>/<file>"; the training loop matches on this exact form.
# Sorted so the list order does not depend on os.listdir.
labeled_files = sorted(data_path + "/" + file for file in files if "train_labeled" in file)
if not labeled_files:
    raise FileNotFoundError(
        f"No train_labeled*.csv files found in {data_path}; run generate_data.ipynb first"
    )

## load the unlabeled training data (a single file shared by every run)
unlabeled = data_to_ganbert(pd.read_csv(f"{data_path}/train_unlabeled.csv"))

## load the test set
test = data_to_ganbert(pd.read_csv(f"{data_path}/test.csv"))

## Set hyper-parameters for GAN-BERT

For more details, see https://github.com/crux82/ganbert-pytorch/blob/main/GANBERT_pytorch.ipynb


In [6]:
# hyper-parameters shared by every GAN-BERT run below
batch_size = 64           # batch size passed to GanBert
max_seq_length = 128      # maximum sequence length passed to GanBert
seed = 0                  # passed to GanBert as random_state
learning_rate = 5e-5      # used for both the discriminator and the generator
epochs = 5                # training epochs per run

# empty table that collects one (n_per_class, accuracy) row per run
results = pd.DataFrame(columns=["n_per_class", "accuracy"])

# Train and evaluate the GAN-BERT classifier

In [None]:
## Train and evaluate GAN-BERT once for each labeled-set size.
for n_per_class in [5, 10, 25, 50]:
    ## locate the labeled training file for this size
    # labeled_files entries are built as "data/<data_name>/<file>", so an exact
    # membership test replaces the manual search. A missing file now raises a
    # clear error instead of calling pd.read_csv("").
    data_file = f"data/{data_name}/train_labeled_{n_per_class}.csv"
    if data_file not in labeled_files:
        raise FileNotFoundError(
            f"Labeled training file not found: {data_file} "
            "(generate it with generate_data.ipynb)"
        )
    print(data_file)
    labeled = data_to_ganbert(pd.read_csv(data_file))

    ## create a new model instance for this run
    ganbert = GanBert(batch_size=batch_size, max_seq_length=max_seq_length, epochs=epochs,
                      learning_rate_discriminator=learning_rate, learning_rate_generator=learning_rate,
                      print_each_n_step=100, random_state=seed)

    ## train and evaluate the model; the value returned by train() is recorded as the accuracy
    performance = ganbert.train(labeled, unlabeled, test)

    ## add to the result data frame
    # DataFrame.append was removed in pandas 2.0, so concatenate a one-row frame.
    # While `results` is still empty, use the new row directly. This avoids
    # pandas' warning about concatenating empty frames and lets column dtypes be
    # inferred from the data instead of the empty frame's object dtype.
    result = {"n_per_class": n_per_class, "accuracy": performance}
    new_row = pd.DataFrame([result])
    results = new_row if results.empty else pd.concat([results, new_row], ignore_index=True)

There are 1 GPU(s) available.
We will use the GPU: Tesla K80
data/medical/train_labeled_5.csv

Training...

  Average training loss generetor: 0.253
  Average training loss discriminator: 3.651
  Training epcoh took: 0:00:33

Running Test...
  Accuracy: 0.208
  Test Loss: 1.729
  Test took: 0:00:06

Training...

  Average training loss generetor: 0.375
  Average training loss discriminator: 2.887
  Training epcoh took: 0:00:33

Running Test...
  Accuracy: 0.246
  Test Loss: 1.691
  Test took: 0:00:06

Training...

  Average training loss generetor: 0.609
  Average training loss discriminator: 2.052
  Training epcoh took: 0:00:33

Running Test...
  Accuracy: 0.225
  Test Loss: 1.830
  Test took: 0:00:06

Training...

  Average training loss generetor: 0.848
  Average training loss discriminator: 1.425
  Training epcoh took: 0:00:33

Running Test...
  Accuracy: 0.252
  Test Loss: 1.946
  Test took: 0:00:06

Training...

  Average training loss generetor: 0.840
  Average training loss dis

In [None]:
# Display the accumulated (n_per_class, accuracy) table as this cell's output.
results

In [None]:
## write the results to results/<data_name>/GanBert_results.csv
result_path = f'results/{data_name}'
# makedirs creates results/ and results/<data_name> in one call. With
# exist_ok=True it does not fail when they already exist, which replaces the
# check-then-mkdir pattern.
os.makedirs(result_path, exist_ok=True)
results.to_csv(f"{result_path}/GanBert_results.csv", index=False)