In [50]:
import sys
sys.path.append('../utils')

from datetime import datetime,timedelta

from ganbert_utils import *
from ganbert_models import *


In [51]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [52]:
##Set random values

seed_val = 42
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed_val)

In [53]:
device = get_gpu_details()

There are 1 GPU(s) available.
We will use the GPU: Tesla V100-SXM2-16GB





In [54]:
#--------------------------------
#  Read Fine Tuning Parameters from FineTuning Configuration File 
#--------------------------------

from ganbert_finetuning_config import *

In [55]:
#--------------------------------
#  Load the Transformers Models 
#--------------------------------
transformer = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [56]:
#--------------------------------
#  Extract the Dataset
#--------------------------------

labeled_examples, unlabeled_examples, _ = get_sst_examples('./../../data/SST-2/train.tsv',test=False,discard_values = 0.99)
_, _, test_examples = get_sst_examples('./../../data/SST-2/dev.tsv', test=True,discard_values = 0)

In [60]:
#--------------------------------
#  Prepare the Training Dataset
#--------------------------------
label_map = {}

for (i, label) in enumerate(label_list):
    label_map[label] = i

train_examples = labeled_examples
#The labeled (train) dataset is assigned with a mask set to True

train_label_masks = np.ones(len(labeled_examples), dtype=bool)

#If unlabel examples are available
if unlabeled_examples:
    train_examples = train_examples + unlabeled_examples
    #The unlabeled (train) dataset is assigned with a mask set to False 
    tmp_masks = np.zeros(len(unlabeled_examples), dtype=bool)
    train_label_masks = np.concatenate([train_label_masks,tmp_masks])
    
    
train_dataloader = generate_data_loader(train_examples, 
                                        train_label_masks, 
                                        label_map, 
                                        tokenizer, 
                                        batch_size=64, 
                                        do_shuffle = True
                                       )  
    

examples len is:  681


In [61]:
#------------------------------
#   Prepare the Test Dataset
#------------------------------
#The labeled (test) dataset is assigned with a mask set to True
test_label_masks = np.ones(len(test_examples), dtype=bool)

test_dataloader = generate_data_loader(test_examples, 
                                       test_label_masks, 
                                       label_map, 
                                       tokenizer, 
                                       batch_size=64,
                                       do_shuffle = False)

examples len is:  1


In [62]:
#------------------------------
#   Prepare the Generator and Discriminator Outputs 
#------------------------------

# The config file is required to get the dimension of the vector produced by 
# the underlying transformer
config = AutoConfig.from_pretrained(model_name)

# currently this notebook has the hidden size of the encoder from bert with 768
hidden_size = int(config.hidden_size)


# Define the number and width of hidden layers
hidden_levels_g = [hidden_size for i in range(0, num_hidden_layers_g)]
hidden_levels_d = [hidden_size for i in range(0, num_hidden_layers_d)]


#-------------------------------------------------
#   Instantiate the Generator and Discriminator
#-------------------------------------------------
generator = Generator(noise_size=noise_size, 
                      output_size=hidden_size, 
                      hidden_sizes=hidden_levels_g, 
                      dropout_rate=out_dropout_rate
                     )

discriminator = Discriminator(input_size=hidden_size, 
                              hidden_sizes=hidden_levels_d, 
                              num_labels=len(label_list), 
                              dropout_rate=out_dropout_rate
                             )

hidden_sizes[i] and [i+1] is 768 and 768


In [63]:
hidden_levels_d

[768]

In [64]:
#-------------------------------------------------
#   Transfer Objects to GPU
#-------------------------------------------------

if torch.cuda.is_available():    
    generator.cuda()
    discriminator.cuda()
    transformer.cuda()
    
    if multi_gpu:
        transformer = torch.nn.DataParallel(transformer)

In [65]:
#-------------------------------------------------
#   Transfer Objects to GPU
#-------------------------------------------------
training_stats = []

# Measure the total training time for the whole run.
total_t0 = time.time()

#models parameters
transformer_vars = [i for i in transformer.parameters()]
d_vars = transformer_vars + [v for v in discriminator.parameters()]
g_vars = [v for v in generator.parameters()]

#optimizer
dis_optimizer = torch.optim.AdamW(d_vars, lr=learning_rate_discriminator)
gen_optimizer = torch.optim.AdamW(g_vars, lr=learning_rate_generator) 

#scheduler
if apply_scheduler:
    num_train_examples = len(train_examples)
    num_train_steps = int(num_train_examples / batch_size * num_train_epochs)
    num_warmup_steps = int(num_train_steps * warmup_proportion)

    scheduler_d = get_constant_schedule_with_warmup(dis_optimizer, 
                                           num_warmup_steps = num_warmup_steps)
    scheduler_g = get_constant_schedule_with_warmup(gen_optimizer, 
                                           num_warmup_steps = num_warmup_steps)
# num_train_epochs =1 
# For each epoch...
for epoch_i in range(0, num_train_epochs):
    # ========================================
    #               Training
    # ========================================
    # Perform one full pass over the training set.
    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, num_train_epochs))
    print('Training...')

    # Measure how long the training epoch takes.
    t0 = time.time()

    # Reset the total loss for this epoch.
    tr_g_loss = 0
    tr_d_loss = 0

    # Put the model into training mode.
    transformer.train() 
    generator.train()
    discriminator.train()

    # For each batch of training data...
    for step, batch in enumerate(train_dataloader):

        # Progress update every print_each_n_step batches.
        if step % print_each_n_step == 0 and not step == 0:
            # Calculate elapsed time in minutes.
            elapsed = format_time(time.time() - t0)
            
            # Report progress.
            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))

        # Unpack this training batch from our dataloader. 
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)
        b_label_mask = batch[3].to(device)

        real_batch_size = b_input_ids.shape[0]
     
        # Encode real data in the Transformer
        model_outputs = transformer(b_input_ids, attention_mask=b_input_mask)
        hidden_states = model_outputs[-1]
        
        # Generate fake data that should have the same distribution of the ones
        # encoded by the transformer. 
        # First noisy input are used in input to the Generator
        noise = torch.zeros(real_batch_size, noise_size, device=device).uniform_(0, 1)
        # Gnerate Fake data
        gen_rep = generator(noise)

        # Generate the output of the Discriminator for real and fake data.
        # First, we put together the output of the tranformer and the generator
        disciminator_input = torch.cat([hidden_states, gen_rep], dim=0)
        # Then, we select the output of the disciminator
        features, logits, probs = discriminator(disciminator_input)
        
        
        
        
        # Finally, we separate the discriminator's output for the real and fake
        # data
        features_list = torch.split(features, real_batch_size)
        D_real_features = features_list[0]
        D_fake_features = features_list[1]
      
        logits_list = torch.split(logits, real_batch_size)
        D_real_logits = logits_list[0]
        D_fake_logits = logits_list[1]
        
        probs_list = torch.split(probs, real_batch_size)
        D_real_probs = probs_list[0]
        D_fake_probs = probs_list[1]

        #---------------------------------
        #  LOSS evaluation
        #---------------------------------
        # Generator's LOSS estimation
        g_loss_d = -1 * torch.mean(torch.log(1 - D_fake_probs[:,-1] + epsilon))
        g_feat_reg = torch.mean(torch.pow(torch.mean(D_real_features, dim=0) - torch.mean(D_fake_features, dim=0), 2))
        g_loss = g_loss_d + g_feat_reg
  
        # Disciminator's LOSS estimation
        logits = D_real_logits[:,0:-1]
        log_probs = F.log_softmax(logits, dim=-1)
        # The discriminator provides an output for labeled and unlabeled real data
        # so the loss evaluated for unlabeled data is ignored (masked)
        label2one_hot = torch.nn.functional.one_hot(b_labels, len(label_list))
        per_example_loss = -torch.sum(label2one_hot * log_probs, dim=-1)
        per_example_loss = torch.masked_select(per_example_loss, b_label_mask.to(device))
        labeled_example_count = per_example_loss.type(torch.float32).numel()

        # It may be the case that a batch does not contain labeled examples, 
        # so the "supervised loss" in this case is not evaluated
        if labeled_example_count == 0:
            D_L_Supervised = 0
        else:
            D_L_Supervised = torch.div(torch.sum(per_example_loss.to(device)), labeled_example_count)
                 
        D_L_unsupervised1U = -1 * torch.mean(torch.log(1 - D_real_probs[:, -1] + epsilon))
        D_L_unsupervised2U = -1 * torch.mean(torch.log(D_fake_probs[:, -1] + epsilon))
        d_loss = D_L_Supervised + D_L_unsupervised1U + D_L_unsupervised2U

        #---------------------------------
        #  OPTIMIZATION
        #---------------------------------
        # Avoid gradient accumulation
        gen_optimizer.zero_grad()
        dis_optimizer.zero_grad()

        # Calculate weigth updates
        # retain_graph=True is required since the underlying graph will be deleted after backward
        g_loss.backward(retain_graph=True)
        d_loss.backward() 
        
        # Apply modifications
        gen_optimizer.step()
        dis_optimizer.step()

        # A detail log of the individual losses
        #print("{0:.4f}\t{1:.4f}\t{2:.4f}\t{3:.4f}\t{4:.4f}".
        #      format(D_L_Supervised, D_L_unsupervised1U, D_L_unsupervised2U,
        #             g_loss_d, g_feat_reg))

        # Save the losses to print them later
        tr_g_loss += g_loss.item()
        tr_d_loss += d_loss.item()

        # Update the learning rate with the scheduler
        if apply_scheduler:
            scheduler_d.step()
            scheduler_g.step()

    # Calculate the average loss over all of the batches.
    avg_train_loss_g = tr_g_loss / len(train_dataloader)
    avg_train_loss_d = tr_d_loss / len(train_dataloader)             
    
    # Measure how long this epoch took.
    training_time = format_time(time.time() - t0)

    print("")
    print("  Average training loss generetor: {0:.3f}".format(avg_train_loss_g))
    print("  Average training loss discriminator: {0:.3f}".format(avg_train_loss_d))
    print("  Training epcoh took: {:}".format(training_time))
        
    # ========================================
    #     TEST ON THE EVALUATION DATASET
    # ========================================
    # After the completion of each training epoch, measure our performance on
    # our test set.
    print("")
    print("Running Test...")

    t0 = time.time()

    # Put the model in evaluation mode--the dropout layers behave differently
    # during evaluation.
    transformer.eval() #maybe redundant
    discriminator.eval()
    generator.eval()

    # Tracking variables 
    total_test_accuracy = 0
   
    total_test_loss = 0
    nb_test_steps = 0

    all_preds = []
    all_labels_ids = []

    #loss
    nll_loss = torch.nn.CrossEntropyLoss(ignore_index=-1)

    # Evaluate data for one epoch
    for batch in test_dataloader:
        
        # Unpack this training batch from our dataloader. 
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)
        
        # Tell pytorch not to bother with constructing the compute graph during
        # the forward pass, since this is only needed for backprop (training).
        with torch.no_grad():        
            model_outputs = transformer(b_input_ids, attention_mask=b_input_mask)
            hidden_states = model_outputs[-1]
            _, logits, probs = discriminator(hidden_states)
            print("Evaluation",hidden_states)
            ###log_probs = F.log_softmax(probs[:,1:], dim=-1)
            filtered_logits = logits[:,0:-1]
            # Accumulate the test loss.
            total_test_loss += nll_loss(filtered_logits, b_labels)
            
        # Accumulate the predictions and the input labels
        _, preds = torch.max(filtered_logits, 1)
        all_preds += preds.detach().cpu()
        all_labels_ids += b_labels.detach().cpu()

    # Report the final accuracy for this validation run.
    all_preds = torch.stack(all_preds).numpy()
    all_labels_ids = torch.stack(all_labels_ids).numpy()
    test_accuracy = np.sum(all_preds == all_labels_ids) / len(all_preds)
    print("  Accuracy: {0:.3f}".format(test_accuracy))

    # Calculate the average loss over all of the batches.
    avg_test_loss = total_test_loss / len(test_dataloader)
    avg_test_loss = avg_test_loss.item()
    
    # Measure how long the validation run took.
    test_time = format_time(time.time() - t0)
    
    print("  Test Loss: {0:.3f}".format(avg_test_loss))
    print("  Test took: {:}".format(test_time))

    # Record all statistics from this epoch.
    training_stats.append(
        {
            'epoch': epoch_i + 1,
            'Training Loss generator': avg_train_loss_g,
            'Training Loss discriminator': avg_train_loss_d,
            'Valid. Loss': avg_test_loss,
            'Valid. Accur.': test_accuracy,
            'Training Time': training_time,
            'Test Time': test_time
        }
    )


Training...
  Batch    10  of     11.    Elapsed: 0:00:03.

  Average training loss generetor: 0.457
  Average training loss discriminator: 2.160
  Training epcoh took: 0:00:04

Running Test...
Evaluation tensor([[-7.4150e-01,  6.8245e-01,  9.9999e-01, -9.9925e-01,  9.3446e-01,
          6.9396e-01,  9.9420e-01, -9.7315e-01, -9.7911e-01,  2.1464e-01,
          9.9354e-01,  9.9985e-01, -9.7603e-01, -9.9991e-01,  6.1979e-01,
         -9.9032e-01,  9.9743e-01, -8.9881e-01, -1.0000e+00, -2.1543e-01,
         -4.6660e-01, -9.9999e-01,  2.6799e-01,  9.3068e-01,  9.5611e-01,
          4.8368e-01,  9.9810e-01,  1.0000e+00,  8.8602e-01,  5.5357e-01,
          5.9344e-01, -9.9866e-01,  8.7762e-01, -9.9997e-01,  5.4566e-01,
         -1.6473e-01,  8.0411e-01, -6.8427e-01,  9.0533e-01, -8.4689e-01,
         -5.9253e-01, -6.7736e-01,  4.5640e-02, -7.3658e-01,  7.0170e-01,
          4.8934e-01,  4.7939e-01,  3.4084e-01, -1.2402e-01,  9.9998e-01,
         -9.7787e-01,  1.0000e+00, -9.7183e-01,  9.998

  Batch    10  of     11.    Elapsed: 0:00:03.

  Average training loss generetor: 0.668
  Average training loss discriminator: 1.357
  Training epcoh took: 0:00:04

Running Test...
Evaluation tensor([[-8.5279e-01,  7.7288e-01,  9.9999e-01, -9.9861e-01,  7.7544e-01,
          2.2096e-01,  9.7110e-01, -9.3852e-01, -9.3704e-01,  5.6132e-01,
          9.9160e-01,  9.9963e-01, -7.5562e-01, -9.9963e-01, -4.6161e-01,
         -9.6348e-01,  9.9060e-01, -9.2539e-01, -9.9999e-01,  8.0202e-01,
         -2.1448e-01, -9.9998e-01,  1.3397e-01,  7.3605e-01,  8.2761e-01,
          5.5782e-01,  9.9611e-01,  9.9999e-01,  7.3785e-01,  9.2882e-01,
          6.2071e-01, -9.9556e-01,  1.5654e-01, -9.9996e-01,  5.0783e-01,
         -1.2384e-01, -1.4373e-01, -7.5777e-01,  6.1309e-02, -6.7878e-01,
         -4.2952e-01,  3.8882e-01, -5.0619e-01, -7.7621e-01,  2.0990e-01,
          5.4458e-01,  5.8499e-01,  2.4094e-01, -2.5250e-01,  9.9998e-01,
         -9.3519e-01,  1.0000e+00, -8.9717e-01,  9.9973e-01,  9.955

  Batch    10  of     11.    Elapsed: 0:00:03.

  Average training loss generetor: 0.836
  Average training loss discriminator: 0.818
  Training epcoh took: 0:00:04

Running Test...
Evaluation tensor([[-0.9492,  0.9070,  1.0000, -0.9986,  0.8859, -0.9606,  0.9923,  0.8147,
         -0.9540,  0.0823,  0.9951,  0.9995,  0.9789, -0.9999, -0.9936, -0.9744,
          0.9887, -0.9723, -1.0000,  0.9912,  0.6667, -1.0000,  0.3785, -0.7506,
          0.9519,  0.5295,  0.9970,  1.0000,  0.8861,  0.9973,  0.7669, -0.9947,
         -0.9471, -1.0000,  0.6547, -0.5219, -0.9792, -0.8559, -0.9202,  0.6664,
         -0.8546,  0.9735, -0.9693, -0.8377, -0.9750,  0.7970,  0.8623,  0.2750,
         -0.6277,  1.0000, -0.9883,  1.0000,  0.8983,  1.0000,  0.9960,  0.7435,
          0.9964,  0.7540,  0.8464,  0.9589,  0.8812, -0.8493,  0.9723, -0.5911,
         -0.9940, -0.9632,  0.9709,  0.8433, -0.7888,  0.9317,  0.6503,  0.7370,
          0.9987, -0.9972, -0.6725, -0.9838,  0.9906, -1.0000,  0.9904,  1.000

  Batch    10  of     11.    Elapsed: 0:00:03.

  Average training loss generetor: 0.722
  Average training loss discriminator: 0.892
  Training epcoh took: 0:00:04

Running Test...
Evaluation tensor([[-0.9267,  0.8891,  1.0000, -0.9946,  0.7959, -0.9762,  0.9569,  0.7718,
         -0.8230,  0.2380,  0.9827,  0.9976,  0.9740, -0.9993, -0.9968, -0.9327,
          0.9677, -0.9646, -1.0000,  0.9936,  0.7496, -1.0000,  0.2414, -0.7155,
          0.7781,  0.5161,  0.9878,  1.0000,  0.4138,  0.9993,  0.6388, -0.9859,
         -0.9709, -0.9999,  0.6213, -0.6095, -0.9913, -0.8205, -0.9716,  0.8380,
         -0.6895,  0.9879, -0.9832, -0.7869, -0.9911,  0.6478,  0.9147,  0.1459,
         -0.6455,  1.0000, -0.9587,  1.0000,  0.8618,  1.0000,  0.9908,  0.7157,
          0.9877,  0.7478,  0.9086,  0.9772,  0.5953, -0.7342,  0.9548, -0.8008,
         -0.9982, -0.8873,  0.9893,  0.8068, -0.6242,  0.8681,  0.4570,  0.7181,
          0.9976, -0.9956, -0.7238, -0.9771,  0.9961, -1.0000,  0.9810,  1.000

  Batch    10  of     11.    Elapsed: 0:00:03.

  Average training loss generetor: 0.732
  Average training loss discriminator: 0.792
  Training epcoh took: 0:00:04

Running Test...
Evaluation tensor([[-0.9411,  0.8777,  1.0000, -0.9971,  0.8703, -0.9706,  0.9735,  0.6175,
         -0.9138,  0.3210,  0.9904,  0.9985,  0.9367, -0.9997, -0.9962, -0.9519,
          0.9799, -0.9644, -1.0000,  0.9849,  0.5631, -1.0000,  0.2443, -0.6994,
          0.8990,  0.5663,  0.9911,  1.0000,  0.7015,  0.9989,  0.6387, -0.9898,
         -0.9750, -0.9999,  0.6598, -0.4560, -0.9933, -0.7994, -0.9635,  0.7570,
         -0.8137,  0.9788, -0.9746, -0.7479, -0.9591,  0.7021,  0.8957,  0.1449,
         -0.6121,  1.0000, -0.9776,  1.0000,  0.6270,  1.0000,  0.9932,  0.7632,
          0.9934,  0.7547,  0.8071,  0.9629,  0.7189, -0.6335,  0.9762, -0.6724,
         -0.9974, -0.8926,  0.9853,  0.7700, -0.6093,  0.8182,  0.6151,  0.6421,
          0.9978, -0.9960, -0.7114, -0.9852,  0.9957, -1.0000,  0.9873,  1.000

  Batch    10  of     11.    Elapsed: 0:00:03.

  Average training loss generetor: 0.717
  Average training loss discriminator: 0.741
  Training epcoh took: 0:00:04

Running Test...
Evaluation tensor([[-0.9504,  0.8993,  1.0000, -0.9983,  0.9063, -0.9394,  0.9692,  0.5988,
         -0.9274,  0.3649,  0.9929,  0.9991,  0.9026, -0.9998, -0.9955, -0.9705,
          0.9849, -0.9723, -1.0000,  0.9850,  0.5998, -1.0000,  0.3268, -0.6122,
          0.9241,  0.5913,  0.9937,  1.0000,  0.7081,  0.9993,  0.6597, -0.9907,
         -0.9703, -0.9999,  0.7616, -0.4935, -0.9938, -0.8376, -0.9693,  0.6636,
         -0.8529,  0.9754, -0.9773, -0.7782, -0.9473,  0.7381,  0.9165,  0.1599,
         -0.6860,  1.0000, -0.9834,  1.0000,  0.3742,  1.0000,  0.9963,  0.8319,
          0.9954,  0.7901,  0.7759,  0.9717,  0.7636, -0.6454,  0.9855, -0.6268,
         -0.9988, -0.9304,  0.9891,  0.7784, -0.6556,  0.8070,  0.6733,  0.6943,
          0.9986, -0.9964, -0.7837, -0.9887,  0.9969, -1.0000,  0.9911,  1.000

In [45]:
for stat in training_stats:
    print(stat)

print("\nTraining complete!")

print("Total training took {:} (h:mm:ss)".format(format_time(time.time()-total_t0)))

{'epoch': 1, 'Training Loss generator': 0.45681632377884607, 'Training Loss discriminator': 2.159539677880027, 'Valid. Loss': 0.5004576444625854, 'Valid. Accur.': 1.0, 'Training Time': '0:00:04', 'Test Time': '0:00:00'}
{'epoch': 2, 'Training Loss generator': 0.6681772470474243, 'Training Loss discriminator': 1.3571972521868618, 'Valid. Loss': 0.20002782344818115, 'Valid. Accur.': 1.0, 'Training Time': '0:00:04', 'Test Time': '0:00:00'}
{'epoch': 3, 'Training Loss generator': 0.8358810977502302, 'Training Loss discriminator': 0.818211089481007, 'Valid. Loss': 0.036212772130966187, 'Valid. Accur.': 1.0, 'Training Time': '0:00:04', 'Test Time': '0:00:00'}
{'epoch': 4, 'Training Loss generator': 0.7768805243752219, 'Training Loss discriminator': 0.8545701612125743, 'Valid. Loss': 0.014373105950653553, 'Valid. Accur.': 1.0, 'Training Time': '0:00:04', 'Test Time': '0:00:00'}
{'epoch': 5, 'Training Loss generator': 0.7217735268852927, 'Training Loss discriminator': 0.8918275508013639, 'Vali

In [46]:
torch.save(
                {
                    'tokenizer': tokenizer,
                    'bert_encoder': transformer.state_dict(),
                    'discriminator': discriminator.state_dict()
            }, 
    f"gan_bert_finetuned_sst2_{len(train_examples)}_samples_{datetime.now():%Y-%m-%d_%H-%M-%S%z}.pt")




# Rough work

In [None]:
probs

In [None]:
train_dataloader.batch_size

In [None]:
dataset = train_dataloader.dataset.tensors

In [None]:
train_label_masks.shape

In [None]:
type(transformer)

In [None]:
tmp_list = []

tmp_list.append(AutoModel.from_pretrained(model_name))
tmp_list.append(Discriminator(input_size=hidden_size, 
                              hidden_sizes=hidden_levels_d, 
                              num_labels=len(label_list), 
                              dropout_rate=out_dropout_rate
                             ).cuda(0))

In [None]:
from torch.distributed.pipeline.sync import Pipe
import tempfile

from torch.distributed import rpc
tmpfile = tempfile.NamedTemporaryFile()
rpc.init_rpc(
    name="worker",
    rank=0,
    world_size=1,
    rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
        init_method="file://{}".format(tmpfile.name),
        # Specifying _transports and _channels is a workaround and we no longer
        # will have to specify _transports and _channels for PyTorch
        # versions >= 1.8.1
        _transports=["ibv", "uv"],
        _channels=["cuda_ipc", "cuda_basic"],
    )
)

save_model = Pipe(torch.nn.Sequential(*tmp_list))

In [None]:
torch.save(save_model, 'pipe_obj.pt')

In [None]:
class InferenceGANBert(nn.Module):
    
    def __init__(self, transformer, discriminator):
        super().__init__()
        self.transformer = transformer
        self.transformer.eval() 
        self.discriminator = discriminator
        self.discriminator.eval()
        
    def forward(self, dataloader, batch_size=64):
        # do the forward pass
        
        device = get_gpu_details()
        
        for batch in dataloader:
            # Unpack this training batch from our dataloader. 
            b_input_ids = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            b_labels = batch[2].to(device)
            with torch.no_grad():        
                model_outputs = self.transformer(b_input_ids, attention_mask=b_input_mask)
                hidden_states = model_outputs[-1]
                _, _, probs = self.discriminator(hidden_states)
                predicted_probs.extend(probs)
                
        return predicted_probs
            

In [15]:
type(transformer)

torch.nn.parallel.data_parallel.DataParallel

In [18]:
type(discriminator)

ganbert_models.Discriminator

In [19]:
type(tokenizer)

transformers.models.bert.tokenization_bert_fast.BertTokenizerFast