# Deep Learning Project - Part3
<div style="text-align: center">
<h1 style = "color: red"> Sharif University Of Technology</h1>
<h2 style = "color: green"> DR. Fatemizadeh </h2>
<h3 style = "color: cyan"> Authors: Amirreza Velaee - Hessam Hosseini - Amirabbas Afzali - Mahshad Moradi<h3>
</div>

In [None]:
import argparse
import os
import random
import torch
import torch.nn as nn 
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import json
from tqdm import tqdm, trange
from sklearn.metrics import precision_recall_fscore_support, matthews_corrcoef
import pickle
from torch.utils.data import random_split
from torch.utils.data import Dataset, Subset
from torch.utils.data import DataLoader,Dataset
from torch.nn.modules import ReLU,Linear,Dropout
import time
import math
import datetime
import torch.nn.functional as F
from collections import OrderedDict


# Set random seed for reproducibility
manualSeed = 42
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

ngpu = 1

In [None]:
# If there's a GPU available...
if torch.cuda.is_available():    
    # Tell PyTorch to use the GPU.    
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
# If not...
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")
    
ngpu = torch.cuda.device_count()

## load the dataset:

In [None]:
!pip install gdown 
import gdown 

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

# id = "11YeloR2eTXcTzdwI04Z-M2QVvIeQAU6-"
id = "1G-XttJCGvkAVkU9N_W_PxR0Cx099qbUa"
gdown.download_folder(id=id, quiet=True, use_cookies=False)

**Subtask B:**

An object of the JSON has the following format:


-  **id** -> identifier of the example,
- **label** -> label (human: 0, chatGPT: 1, cohere: 2, davinci: 3, bloomz: 4, dolly: 5),
- **text** -> text generated by machine or written by human,
- **model** -> model name that generated data,
- **source** -> source (Wikipedia, Wikihow, Peerread, Reddit, Arxiv) on English


In [None]:
# content/drive/My Drive/Project
with open('/kaggle/working/SubtaskB/subtaskB_train.jsonl', 'r') as file:    
    lines = file.readlines()

# Parse each line as a JSON object
train_objects = [json.loads(line) for line in lines] 

In [None]:
with open('/kaggle/working/SubtaskB/subtaskB_dev.jsonl', 'r') as file:
    lines = file.readlines()

dev_objects = [json.loads(line) for line in lines]

In [None]:
len(train_objects), len(dev_objects) 

an example:

In [None]:
train_objects[100].keys()

In [None]:
train_objects[100]['model'], train_objects[100]['source'],train_objects[100]['label'],train_objects[100]['id']

In [None]:
print(train_objects[104]['text'])

More details about the dataset and the Exploratory Data Analysis have been reported in `EDA.ipynb`.

## load the pretrained `RoBERTa`/`DidstilBert` from $huggingface$:

In [None]:
!pip install transformers
!pip install sentencepiece

In [None]:
from transformers import RobertaTokenizer, RobertaModel
from transformers import DistilBertTokenizer, DistilBertModel

import sentencepiece
from transformers import get_constant_schedule_with_warmup

In [None]:
# Load pre-trained BERT model and tokenizer
# model_name = 'roberta-large'
# tokenizer = RobertaTokenizer.from_pretrained(model_name)
# bert_model = RobertaModel.from_pretrained(model_name)

model_name = 'distilbert-base-uncased'
tokenizer = DistilBertTokenizer.from_pretrained(model_name)
bert_model = DistilBertModel.from_pretrained(model_name)

bert_generator = DistilBertModel.from_pretrained(model_name)

In [None]:
class BERT_Embedder(nn.Module):
    def __init__(self, bert_modele):
        super(BERT_Embedder, self).__init__()
        self.bert = bert_modele

    def forward(self, encoded_ids,attention_mask):
        outputs = self.bert(encoded_ids,attention_mask)
        last_hidden_states = outputs.last_hidden_state[:,0]  # return embedding of 'CLS' token for classification.

        return last_hidden_states

# **Bag of Words**

## Data cleaning

- Expand Contractions

In [None]:
# For regular expressions
import re

# Dictionary of English Contractions
contractions_dict = { "ain't": "are not","'s":" is","aren't": "are not",
                     "can't": "cannot","can't've": "cannot have",
                     "'cause": "because","could've": "could have","couldn't": "could not",
                     "couldn't've": "could not have", "didn't": "did not","doesn't": "does not",
                     "don't": "do not","hadn't": "had not","hadn't've": "had not have",
                     "hasn't": "has not","haven't": "have not","he'd": "he would",
                     "he'd've": "he would have","he'll": "he will", "he'll've": "he will have",
                     "how'd": "how did","how'd'y": "how do you","how'll": "how will",
                     "I'd": "I would", "I'd've": "I would have","I'll": "I will",
                     "I'll've": "I will have","I'm": "I am","I've": "I have", "isn't": "is not",
                     "it'd": "it would","it'd've": "it would have","it'll": "it will",
                     "it'll've": "it will have", "let's": "let us","ma'am": "madam",
                     "mayn't": "may not","might've": "might have","mightn't": "might not", 
                     "mightn't've": "might not have","must've": "must have","mustn't": "must not",
                     "mustn't've": "must not have", "needn't": "need not",
                     "needn't've": "need not have","o'clock": "of the clock","oughtn't": "ought not",
                     "oughtn't've": "ought not have","shan't": "shall not","sha'n't": "shall not",
                     "shan't've": "shall not have","she'd": "she would","she'd've": "she would have",
                     "she'll": "she will", "she'll've": "she will have","should've": "should have",
                     "shouldn't": "should not", "shouldn't've": "should not have","so've": "so have",
                     "that'd": "that would","that'd've": "that would have", "there'd": "there would",
                     "there'd've": "there would have", "they'd": "they would",
                     "they'd've": "they would have","they'll": "they will",
                     "they'll've": "they will have", "they're": "they are","they've": "they have",
                     "to've": "to have","wasn't": "was not","we'd": "we would",
                     "we'd've": "we would have","we'll": "we will","we'll've": "we will have",
                     "we're": "we are","we've": "we have", "weren't": "were not","what'll": "what will",
                     "what'll've": "what will have","what're": "what are", "what've": "what have",
                     "when've": "when have","where'd": "where did", "where've": "where have",
                     "who'll": "who will","who'll've": "who will have","who've": "who have",
                     "why've": "why have","will've": "will have","won't": "will not",
                     "won't've": "will not have", "would've": "would have","wouldn't": "would not",
                     "wouldn't've": "would not have","y'all": "you all", "y'all'd": "you all would",
                     "y'all'd've": "you all would have","y'all're": "you all are",
                     "y'all've": "you all have", "you'd": "you would","you'd've": "you would have",
                     "you'll": "you will","you'll've": "you will have", "you're": "you are",
                     "you've": "you have"}

# Regular expression for finding contractions
contractions_re=re.compile('(%s)' % '|'.join(contractions_dict.keys()))

# Function for expanding contractions
def expand_contractions(text,contractions_dict=contractions_dict):
    def replace(match):
        return contractions_dict[match.group(0)]
    return contractions_re.sub(replace, text)


- Remove digits and words containing digits

In [None]:
def demove_digts(x):
    return re.sub('\w*\d\w*','', x)

- Remove Punctuations

In [None]:
import string

def demove_punctuations(x):
    out = re.sub('[%s]' % re.escape(string.punctuation), '', x)
    return re.sub(' +',' ',out) # Removing extra spaces

In [None]:
from collections import Counter

class Fake_Dataset(Dataset):
    """
    Generate the fake sentences as inputs of G2 (BERT Generator) 
    """
    def __init__(self, json_file,seq_length, tokenizer):

        self.json_file = json_file
        self.seq_length = seq_length
        self.tokenizer = tokenizer
        self.weighted_keys = self.create_distribiution()

    def __len__(self):
        return len(self.json_file)
    
    def text_cleaner(self,text):
         return demove_punctuations(demove_digts(expand_contractions(text)))
    
    
    def create_distribiution(self):
        token_counts = Counter([]) 
        for sample in tqdm(self.json_file):
            text = self.text_cleaner(sample['text'])
            tokens_a = self.tokenizer.tokenize(text)
            input_ids = self.tokenizer.convert_tokens_to_ids(tokens_a)
            token_counts += Counter(input_ids) 

        token_counts_dict = dict(token_counts)
        weighted_keys = [key for key, count in token_counts_dict.items() for _ in range(count)]
        return weighted_keys
        

    def sampler(self,seq_length):
        input_ids = random.choices(self.weighted_keys, k=seq_length)
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        input_mask = torch.ones(input_ids.shape[0],dtype=torch.int32)
        return input_ids, input_mask 

    def __getitem__(self,idx):
        input_ids, input_mask = self.sampler(self.seq_length)
        return input_ids, input_mask #input_ids.squeeze(0)

Now we can define **Discriminator** and **Generator** completely:

In [None]:
# custom weights initialization
import torch.nn.init as init

def custom_weights_init(m):
    if isinstance(m, nn.Linear):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.constant_(m.bias.data, 0)


In [None]:
class Discriminator(nn.Module):
    def __init__(self, input_size, num_classes,dropout_rate=0.2,relu_slop=0.2):
        super(Discriminator, self).__init__()
        self.input_size = input_size
        self.num_classes = num_classes

        self.main = torch.nn.Sequential(
            Dropout(p=dropout_rate),
            
            Linear(in_features=self.input_size, out_features=512, bias=True),
            nn.LeakyReLU(relu_slop, inplace=True),
            Dropout(p=dropout_rate),
            
            Linear(in_features=512, out_features=256, bias=True),
            nn.LeakyReLU(relu_slop, inplace=True),
            Dropout(p=dropout_rate),
            
            Linear(in_features=256, out_features=256, bias=True),
            nn.LeakyReLU(relu_slop, inplace=True),
            Dropout(p=dropout_rate),
        )

        self.logit = nn.Linear(256,self.num_classes+1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input):
        last_rep = self.main(input)  # for do 'feature matching'
        logits = self.logit(last_rep)
        probs = self.softmax(logits)
        return last_rep, logits, probs


In [None]:
class Generator1(nn.Module):
    def __init__(self, input_size, output_size,dropout_rate=0.2,relu_slop=0.2):
        super(Generator1, self).__init__()

        self.input_size = input_size
        self.output_size = output_size

        self.main = torch.nn.Sequential(
            Linear(in_features=self.input_size, out_features=256, bias=True),
            nn.LeakyReLU(relu_slop, inplace=True),
            Dropout(p=dropout_rate, inplace=False),
            Linear(in_features=256, out_features=self.output_size, bias=True),
        )

    def forward(self, input):
        return self.main(input)


In [None]:
class Generator2(nn.Module):
    def __init__(self, bert_modele):
        super(Generator2, self).__init__()
        self.bert = bert_modele

    def forward(self, encoded_ids,attention_mask):
        outputs = self.bert(encoded_ids,attention_mask)
        last_hidden_states = outputs.last_hidden_state[:,0]  # return embedding of 'CLS' token for classification.

        return last_hidden_states

Set Hyperparameter :

In [None]:
num_classes = 6
input_size = 768
noise_size = 100
label_list = list(range(6))

### Define the Dataset class:

In [None]:

class SemEval_Dataset(Dataset):
    def __init__(self, json_file,label_list,label_masks,
                 max_seq_length, tokenizer,dtype=torch.long):

        self.json_file = json_file
        self.dtype = dtype
        self.label_list = label_list # [0, 1, 2, 3, 4, 5]
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.label_masks = label_masks

    def __len__(self):
        return len(self.json_file)

    def feature_extractor(self, text, label=None):
        features = []
        tokenized_text = tokenizer(text,padding='max_length', truncation=True,
                                   max_length=self.max_seq_length,
                                   return_tensors="pt")

        input_ids = tokenized_text['input_ids']
        input_mask = tokenized_text['attention_mask']

        if len(input_ids) > self.max_seq_length:
            input_ids = input_ids[0:(self.max_seq_length)]   # crop long sentences
            input_mask = input_mask[0:(self.max_seq_length)]

        assert len(input_ids[0]) == self.max_seq_length
        assert len(input_mask[0]) == self.max_seq_length

        if label != None:
            return input_ids, input_mask, label
        else:
            return input_ids, input_mask

    def __getitem__(self, idx):
        data = self.json_file[idx]
        input_ids, input_mask, label_id = self.feature_extractor(data['text'], label=data['label'])

        return input_ids.squeeze(0), input_mask.squeeze(0), data['label'], self.label_masks[idx]

Create Dataset and Dataloader :

In [None]:
max_seq_length = 256
batch_size = 64

In [None]:
unlabeled_examples = True
labeled_ratio = 0.5               # 0.01, 0.1 ,0.05 ,0.5 
train_dataset_size_labeled = int(labeled_ratio* len(train_objects))

#The labeled (train) dataset is assigned with a mask set to True
train_label_masks = torch.ones(train_dataset_size_labeled, dtype=bool)
#If unlabel examples are available
if unlabeled_examples:
  #The unlabeled (train) dataset is assigned with a mask set to False
    tmp_masks = torch.zeros(len(train_objects)- train_dataset_size_labeled , dtype=bool)
    train_label_masks = torch.concatenate([train_label_masks,tmp_masks])
    idx = torch.randperm(train_label_masks.shape[0])
    train_label_masks = train_label_masks[idx].view(train_label_masks.size())

assert train_label_masks.shape[0] == len(train_objects)
train_dataset = SemEval_Dataset(train_objects, label_list, train_label_masks,max_seq_length, tokenizer)
# train_dataset = torch.utils.data.Subset(train_dataset, [i for i in range(train_dataset_size)])

In [None]:
train_size = int(0.8 * len(train_objects))  # 80% for training
val_size = len(train_objects) - train_size  # Remaining 20% for validation

train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])
test_label_masks = torch.ones(len(dev_objects), dtype=bool)
test_dataset = SemEval_Dataset(dev_objects, label_list, test_label_masks,max_seq_length, tokenizer)

train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, num_workers=os.cpu_count(),shuffle=True, drop_last=False)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=batch_size, num_workers=os.cpu_count(),shuffle=True, drop_last=False)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, num_workers=os.cpu_count(),shuffle=True, drop_last=False)

In [None]:
print('Number of train samples: ', len(train_dataset))
print('Number of validation samples: ', len(val_dataset))
print('Number of test samples: ', len(test_dataset))

In [None]:
# test train_dataset
betch = next(iter(train_dataloader))
print(f'input_ids shape: {betch[0].shape}, \ninput_mask shape: {betch[1].shape}, \
        \nlabel_ids shape: {betch[2].shape},\nlabel_mask shape: {betch[3].shape}')

In [None]:
# for batch in train_dataloader:
# #     print(batch[2][batch[3]])
#     print()
#     print(batch[1])
#     break

Create a **Fake_Dataset** :

In [None]:
fake_dataset =  Fake_Dataset(train_objects,256, tokenizer)

In [None]:
noisy_dataloader = DataLoader(dataset=fake_dataset, batch_size=batch_size, num_workers=os.cpu_count(),shuffle=True)

In [None]:
next(iter(noisy_dataloader))[0]

## GAN-BERT

In [None]:
import matplotlib.pyplot as plt

img = plt.imread('/content/drive/My Drive/Project/GAN-BERT.png')
plt.imshow(img);
plt.axis('off');

Hyperparameters:

In [None]:
epoch_num = 4

learning_rate = 5e-4
noise_size = 100
epsilon = 1e-8
warmup_proportion = 0.1  #TODO

In [None]:
# Create the Discriminator and Generator
discriminator = Discriminator(input_size,num_classes).to(device)
generator2 = Generator2(bert_generator).to(device)
bert = BERT_Embedder(bert_model).to(device)


# Handle multi-GPU if desired
if (device.type == 'cuda') and (ngpu > 1):
    discriminator = nn.DataParallel(discriminator, list(range(ngpu)))
    generator = nn.DataParallel(generator2, list(range(ngpu)))    
    bert = nn.DataParallel(bert, list(range(ngpu)))

# weights initialization   # TODO : Xavier weight initialization
discriminator.apply(custom_weights_init)
# generator1.apply(custom_weights_init)

# print(discriminator)
# print()
# print(generator1)
# print()
# print(bert)

### Define the Optimizers and Scheduler:

In [None]:
gen_optimizer = torch.optim.AdamW(generator2.parameters(), lr=learning_rate)
dis_optimizer = torch.optim.AdamW(list(bert.parameters()) + list(discriminator.parameters()), lr=learning_rate)

#scheduler
num_train_examples = len(train_dataset)
num_train_steps = int(num_train_examples / batch_size * epoch_num) 
num_warmup_steps = int(num_train_steps * warmup_proportion)

scheduler_d = get_constant_schedule_with_warmup(dis_optimizer, 
                                       num_warmup_steps = num_warmup_steps)
scheduler_g = get_constant_schedule_with_warmup(gen_optimizer, 
                                       num_warmup_steps = num_warmup_steps) 

 The loss function of $Discriminator$ is defined as:  $$\quad L_{\mathcal{D}}=L_{\mathcal{D}_{\text {sup. }}}+L_{\mathcal{D}_{\text {unsup. }}}$$
 where:
$$
\begin{aligned}
L_{\mathcal{D}_{\text {sup. }}} & =-\mathbb{E}_{x, y \sim p_d} \log \left[p_{\mathrm{m}}(\hat{y}=y \mid x, y \in(1, \ldots, k))\right] \\
L_{\mathcal{D}_{\text {unsup. }}} & =-\mathbb{E}_{x \sim p_d} \log \left[1-p_{\mathrm{m}}(\hat{y}=y \mid x, y=k+1)\right] -\mathbb{E}_{x \sim \mathcal{G}} \log \left[p_{\mathrm{m}}(\hat{y}=y \mid x, y=k+1)\right] \\
\rightarrow  L_{\mathcal{D}_{\text {unsup. }}} & =-\mathbb{E}_{x \sim p_d} [\log (\mathcal{D}(x))] -\mathbb{E}_{x \sim \mathcal{G}} 
[\log (1-\mathcal{D}(x))]
\end{aligned}
$$

And loss function of $Generator$ is defined as: $$\quad L_{\mathcal{G}}=L_{\mathcal{G}_{\text {feature matching }}}+L_{\mathcal{G}_{\text {unsup. }}}$$ 
where:

$$L_{\mathcal{G}_{\text {unsup. }}}=-\mathbb{E}_{x \sim \mathcal{G}} 
\log \left[1-p_m(\hat{y}=y \mid x, y=k+1)\right]$$

$$ L_{\mathcal{G}_{\text {feature matching }}} = ||\mathbb{E}_{x \sim p_d} f(x) 
- \mathbb{E}_{x \sim \mathcal{G}} f(x) ||_2^2$$


In [None]:
class GANBERT():
    def __init__(self, discriminator, generator, bert,gen_optimizer, dis_optimizer,
                scheduler_d,scheduler_g, path,G2=False): 

        self.discriminator = discriminator
        self.generator = generator
        self.bert = bert
        self.gen_optimizer = gen_optimizer
        self.dis_optimizer = dis_optimizer
        self.scheduler_g = scheduler_g
        self.scheduler_d = scheduler_d
        self.nll_loss = torch.nn.CrossEntropyLoss(ignore_index=-1 , label_smoothing=0.005) # which one
        self.path = path
        self.G2 = G2
 
    def trainer(self, epoch_num, label_list,labeled_ratio,
               train_dataloader, val_dataloader=None, noisy_dataloader=None
                ,report=True):
        
        best_score = 1e-5
    
        def format_time(elapsed):
            '''
            Takes a time in seconds and returns a string hh:mm:ss
            '''
            # Round to the nearest second.
            elapsed_rounded = int(round((elapsed)))
            # Format as hh:mm:ss
            return str(datetime.timedelta(seconds=elapsed_rounded))
        
        results = []
        print(f'With labeled_ratio : {labeled_ratio}\n')
        for epoch in range(epoch_num):
            # Measure how long the each epoch takes.
            t0 = time.time()
            
            self.bert.train()
            self.generator.train()
            self.discriminator.train()

            tr_g_loss = 0
            tr_d_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0

            print(f'Epoch {epoch+1}/{epoch_num} :')
            for step, batch in enumerate(train_dataloader):

                src_input_ids, src_input_mask, label_ids, b_label_mask = batch # unpacking
                src_input_ids = src_input_ids.to(device)
                src_input_mask = src_input_mask.to(device)
                label_ids = label_ids.to(device)
                b_label_mask = b_label_mask.to(device)

                self.bert.zero_grad()
                self.discriminator.zero_grad()

                # Real representations
                embedding = self.bert(src_input_ids, attention_mask=src_input_mask)
                D_real_features, D_real_logits, D_real_probs = self.discriminator(embedding)

                # Random noise
                if self.G2:
                    noisy_input_ids, noisy_input_mask = next(iter(noisy_dataloader))
                    noisy_input_ids = noisy_input_ids.to(device)
                    noisy_input_mask = noisy_input_mask.to(device)
                    gen_rep = self.generator(noisy_input_ids, attention_mask=noisy_input_mask)
                
                else:
                    noise = torch.zeros(src_input_ids.shape[0],noise_size, device=device).uniform_(0, 1)
                    gen_rep = self.generator(noise)
                
                
                ############################
                # Update Generator network: minimize -E[log(D(G(z)))] + feature_matching LOSS
                ###########################
                D_fake_features, D_fake_logits, D_fake_probs = self.discriminator(gen_rep) # .detach()

                g_loss_d = -1 * torch.mean(torch.log(1 - D_fake_probs[:,-1] + epsilon))
                g_feat_reg = torch.mean(torch.pow(torch.mean(D_real_features, dim=0) - torch.mean(D_fake_features, dim=0), 2))
                g_loss = g_loss_d + g_feat_reg

                ############################
                #  Update Discriminator network: minimize -E[log(D(x)) + log(1 - D(G(z)))]
                ###########################
                logits = D_real_logits[:,0:-1]
                log_probs = F.log_softmax(logits, dim=-1)
                # The discriminator provides an output for labeled and unlabeled real data
                # so the loss evaluated for unlabeled data is ignored (masked)
                label2one_hot = torch.nn.functional.one_hot(label_ids, len(label_list))
                per_example_loss = -torch.sum(label2one_hot * log_probs, dim=-1)
                per_example_loss = torch.masked_select(per_example_loss, b_label_mask.to(device))
                labeled_example_count = per_example_loss.type(torch.float32).numel()

                # It may be the case that a batch does not contain labeled examples,
                # so the "supervised loss" in this case is not evaluated
                if labeled_example_count == 0:
                    D_L_Supervised = 0
                else:
                    D_L_Supervised = torch.div(torch.sum(per_example_loss.to(device)), labeled_example_count)

                D_L_unsupervised1U = -1 * torch.mean(torch.log(1 - D_real_probs[:, -1] + epsilon))
                D_L_unsupervised2U = -1 * torch.mean(torch.log(D_fake_probs[:, -1] + epsilon))
                d_loss = D_L_Supervised + D_L_unsupervised1U + D_L_unsupervised2U

                #---------------------------------
                #  OPTIMIZATION
                #---------------------------------
                self.gen_optimizer.zero_grad()
                self.dis_optimizer.zero_grad()

                # Calculate weigth updates
                # retain_graph=True is required since the underlying graph will be deleted after backward
                g_loss.backward(retain_graph=True)
                d_loss.backward() 

                # Apply modifications
                self.gen_optimizer.step()
                self.dis_optimizer.step()

                # Save the losses to print them later
                tr_g_loss += g_loss.item()
                tr_d_loss += d_loss.item()

            # Output training stats
                if report:
                    if step % 100 == 0:
                        print('''\n[Epoch %d/%d][iter %d/%d]\ttotal Loss_D: %.4f\ttotal Loss_G: %.4f,\n
                        details of Loss_D:  Loss_D_sup: %.4f,\t-E[log(D(x))]: %.4f,\t-E[log(1-D(G(z)))]: %.4f,\n
                        details of Loss_G:  -E[log(D(G(z)))]: %.4f,\tLoss_G_feat: %.4f\n
                        D(x): %.4f\tD(G(z)): %.4f'''
                          %(epoch+1, epoch_num, step, len(train_dataloader),
                            d_loss.mean().item(), g_loss.mean().item(), 
                            D_L_Supervised, D_L_unsupervised1U, D_L_unsupervised2U,
                            g_loss_d, g_feat_reg,
                            torch.mean(D_real_probs[:, -1]).item(), 
                              torch.mean(D_fake_probs[:, -1]).item() ))
                        
                        # save checkpoints
                        self.save_checkpoint(epoch)

            # Update the learning rate with the scheduler
            self.scheduler_d.step()
            self.scheduler_g.step()

            # Calculate the average loss over all of the batches.
            avg_train_loss_g = tr_g_loss / len(train_dataloader)
            avg_train_loss_d = tr_d_loss / len(train_dataloader)

            # Measure how long this epoch took.
            epoch_time = format_time(time.time() - t0)

            print("")
            print(f' Training stats at epoch {epoch+1}: ')
            print(f' G_loss = {tr_g_loss}, D_loss = {tr_d_loss} \n')
            print(" Training epcoh took: {:}".format(epoch_time))
            
            if val_dataloader != None:
                self.bert.eval()
                self.discriminator.eval() 

                all_preds = np.array([])
                all_label_ids = np.array([])
                eval_loss = 0
                nb_eval_steps = 0
                for val_step, batch in enumerate(val_dataloader):
                    src_input_ids, src_input_mask, label_ids, _ = batch # unpacking
                    src_input_ids = src_input_ids.to(device)
                    src_input_mask = src_input_mask.to(device)
                    label_ids = label_ids.to(device)


                    with torch.no_grad():
                        doc_rep = self.bert(src_input_ids, attention_mask=src_input_mask)
                        _, logits, _ = self.discriminator(doc_rep)
#                         probs = torch.nn.functional.softmax(logits[:,0:-1], dim=-1)
                        probs = logits[:,0:-1]
                        tmp_eval_loss = self.nll_loss(probs, label_ids.view(-1))

                    eval_loss += tmp_eval_loss.mean().item()

                    probs = probs.detach().cpu().numpy()
                    label_ids = label_ids.to('cpu').numpy()
                    all_preds = np.append(all_preds, np.argmax(probs, axis=1))
                    all_label_ids = np.append(all_label_ids, label_ids)

                    nb_eval_steps += 1

                eval_loss = eval_loss / nb_eval_steps
#                 precision, recall, f1, _ = precision_recall_fscore_support(all_label_ids, all_preds, average="micro",
#                                                                          labels=list(range(0,len(label_list))))
                mcc = matthews_corrcoef(all_preds, all_label_ids)
                acc = (all_preds == all_label_ids).sum().item() / all_label_ids.shape[0]


                # Output validation stats
                print(f'Validation stats: ')
                print('Loss: %.4f,\tAccuracy: %.4f,\tmcc: %.4f,'
                  %(eval_loss,acc,mcc))
                
            result = {
                'epoch': epoch_time,
                "gen_loss": tr_g_loss,
                "dis_loss": tr_d_loss,
                "eval_loss": eval_loss,
                "mcc": mcc,
                "acc": acc,
                'epoch_time': epoch_time}
#                 "precision_micro": precision,
#                 "recall_micro": recall,
#                 "f1_micro": f1,
                

            results.append(result)
            # save checkpoints
            self.save_checkpoint(epoch,results)
            
            # seva best model
            if acc > best_score:
                best_score = acc 
                self.save_checkpoint(epoch ,result,best=True)
            
    def save_checkpoint(self,epoch,results=None,best=False):
        checkpoint = {
            'epoch': epoch + 1,
            'bert_state_dict': self.bert.state_dict(),
            'disc_state_dict': self.discriminator.state_dict(),
            'gen_state_dict': self.generator.state_dict(),
            'disc_optimizer_state_dict': self.dis_optimizer.state_dict(),
            'gen_optimizer_state_dict': self.gen_optimizer.state_dict(),
            }
        # for colab : /content/drive/My Drive/Project/checkpoints
        if best:
            torch.save(checkpoint, f'{self.path}/GAN_BERT_checkpoint_BEST.pth')
            if results!= None:
                with open(f'{self.path}/results_BEST.pickle', 'wb') as file:
                    pickle.dump(results, file)
            
        else:
            torch.save(checkpoint, f'{self.path}/GAN_BERT_checkpoint{epoch+1}.pth')
            if results!= None:
                with open(f'{self.path}/results.pickle', 'wb') as file:
                    pickle.dump(results, file)
    
    def test(self, test_dataloader):
        self.bert.eval()
        self.discriminator.eval()

        all_preds = np.array([])
        all_label_ids = np.array([])
        eval_loss = 0
        nb_eval_steps = 0
        nll_loss = torch.nn.CrossEntropyLoss(ignore_index=-1)
        
        for val_step, batch in enumerate(test_dataloader):
            src_input_ids, src_input_mask, label_ids, _ = batch # unpacking
            src_input_ids = src_input_ids.to(device)
            src_input_mask = src_input_mask.to(device)
            label_ids = label_ids.to(device)


            with torch.no_grad():
                doc_rep = self.bert(src_input_ids, attention_mask=src_input_mask)
                _, logits, _ = self.discriminator(doc_rep)
            # probs = torch.nn.functional.softmax(logits[:,0:-1], dim=-1)
            probs = logits[:,0:-1]    
            tmp_eval_loss = nll_loss(probs, label_ids.view(-1))

            eval_loss += tmp_eval_loss.mean().item()

            probs = probs.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()
            
            all_preds = np.append(all_preds, np.argmax(probs, axis=1))
            all_label_ids = np.append(all_label_ids, label_ids)

            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps


        mcc = matthews_corrcoef(all_preds, all_label_ids)
        acc = (all_preds == all_label_ids).sum().item() / all_label_ids.shape[0]
        
        # Output validation stats
        print(f'Test stats: ')
        print('Total loss: %.4f,\tAccuracy: %.4f,\tmcc: %.4f,'
          %(eval_loss,acc,mcc) ) 
        return all_preds

    @staticmethod
    def rename_keys(original_ordered_dict):
        new_keys_mapping = dict()
        for a in list(original_ordered_dict.keys()):
            new_keys_mapping[a] = a.split('module.')[-1] 

        return OrderedDict((new_keys_mapping.get(k, k), v) for k, v in original_ordered_dict.items())

    
    def load_checkpoint(self,checkpoint_path):
        state_dict = torch.load(checkpoint_path)
        
        if (device.type == 'cuda') and (ngpu > 1):
            # Load the state dictionary into the model
            self.bert.load_state_dict(state_dict['bert_state_dict'])
            self.discriminator.load_state_dict(state_dict['disc_state_dict'])
            self.generator.load_state_dict(state_dict['gen_state_dict'])
            self.dis_optimizer.load_state_dict(state_dict['disc_optimizer_state_dict'])
            self.gen_optimizer.load_state_dict(state_dict['gen_optimizer_state_dict'])
            
        else: 
            self.bert.load_state_dict(self.rename_keys(state_dict['bert_state_dict']))
            self.discriminator.load_state_dict(self.rename_keys(state_dict['disc_state_dict']))
            self.generator.load_state_dict(self.rename_keys(state_dict['gen_state_dict']))
            self.dis_optimizer.load_state_dict(state_dict['disc_optimizer_state_dict'])
            self.gen_optimizer.load_state_dict(state_dict['gen_optimizer_state_dict'])

        print('Loaded !')
            
    def plot_results():
        pass
            
    def show_tensorboard():
        pass

In [None]:
# !pip install numba

# from numba import cuda
# device = cuda.get_current_device()
# device.reset()
# torch.cuda.empty_cache()  

In [None]:
# !ls
!mkdir part3
!ls

----
## With **G1** :

In [None]:
ganbert = GANBERT(discriminator, generator1, bert,gen_optimizer, dis_optimizer,
                scheduler_d,scheduler_g, path='/kaggle/working/part3') 

In [None]:
ganbert.trainer(epoch_num,label_list,labeled_ratio,train_dataloader, val_dataloader,report=True)

In [None]:
test_res = ganbert.test(test_dataloader) 

The `Matthews correlation coefficient` , is a measure of the quality of classifications in machine learning. It takes into account true and false positives and negatives and is generally regarded as a balanced measure which can be used even if the classes are of very different sizes. It's defined in the range from -1 to 1, with 1 being a perfect prediction, 0 being the result of a random prediction, and -1 indicating total disagreement between prediction and observation.

---
### Load the best model

In [None]:
discriminator = Discriminator(input_size,num_classes).to(device)
generator1 = Generator1(noise_size,input_size).to(device)
bert = BERT_Embedder(bert_model).to(device)

if (device.type == 'cuda') and (ngpu > 1):
    discriminator = nn.DataParallel(discriminator, list(range(ngpu)))
    generator = nn.DataParallel(generator1, list(range(ngpu)))    
    bert = nn.DataParallel(bert, list(range(ngpu)))
    
gen_optimizer = torch.optim.AdamW(generator1.parameters(), lr=learning_rate)
dis_optimizer = torch.optim.AdamW(list(bert.parameters()) + list(discriminator.parameters()), lr=learning_rate)

#scheduler
num_train_examples = len(train_dataset)
num_train_steps = int(num_train_examples / batch_size * epoch_num) 
num_warmup_steps = int(num_train_steps * warmup_proportion)

scheduler_d = get_constant_schedule_with_warmup(dis_optimizer, 
                                       num_warmup_steps = num_warmup_steps)
scheduler_g = get_constant_schedule_with_warmup(gen_optimizer, 
                                       num_warmup_steps = num_warmup_steps) 

ganbert_best = GANBERT(discriminator, generator1, bert,gen_optimizer, dis_optimizer,
                scheduler_d,scheduler_g, path='/kaggle/working/part3') 

In [None]:
path = '/kaggle/working/part3/GAN_BERT_checkpoint_BEST.pth'
ganbert_best.load_checkpoint(path)

In [None]:
test_res = ganbert_best.test(test_dataloader) 

----
## With **G2** :

In [None]:
ganbert2 = GANBERT(discriminator, generator2, bert,gen_optimizer, dis_optimizer,
                scheduler_d,scheduler_g, path='/kaggle/working/part3', G2=True) 


ganbert2.trainer(epoch_num,label_list,labeled_ratio,train_dataloader, val_dataloader,noisy_dataloader,report=True)
# test_res = ganbert.test(test_dataloader) 

In [47]:
ganbert2.load_checkpoint('/kaggle/working/part3/GAN_BERT_checkpoint_BEST.pth')

Loaded !


In [None]:
ganbert2.trainer(3,label_list,labeled_ratio,train_dataloader, val_dataloader,noisy_dataloader,report=True)

With labeled_ratio : 0.5

Epoch 1/3 :

[Epoch 1/3][iter 0/888]	total Loss_D: 1.3547	total Loss_G: 0.7173,

                        details of Loss_D:  Loss_D_sup: 0.6033,	-E[log(D(x))]: 0.0178,	-E[log(1-D(G(z)))]: 0.7336,

                        details of Loss_G:  -E[log(D(G(z)))]: 0.6742,	Loss_G_feat: 0.0432

                        D(x): 0.0175	D(G(z)): 0.4857

[Epoch 1/3][iter 100/888]	total Loss_D: 1.7101	total Loss_G: 0.7100,

                        details of Loss_D:  Loss_D_sup: 0.9509,	-E[log(D(x))]: 0.0151,	-E[log(1-D(G(z)))]: 0.7441,

                        details of Loss_G:  -E[log(D(G(z)))]: 0.6662,	Loss_G_feat: 0.0438

                        D(x): 0.0150	D(G(z)): 0.4811

[Epoch 1/3][iter 200/888]	total Loss_D: 1.2740	total Loss_G: 0.7294,

                        details of Loss_D:  Loss_D_sup: 0.5329,	-E[log(D(x))]: 0.0169,	-E[log(1-D(G(z)))]: 0.7242,

                        details of Loss_G:  -E[log(D(G(z)))]: 0.6867,	Loss_G_feat: 0.0427

                        

In [48]:
test_res = ganbert2.test(test_dataloader) 

Test stats: 
Total loss: 1.4534,	Accuracy: 0.5413,	mcc: 0.4610,


---

In [None]:
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()