**To Do List**

1. Data loader (the training loop iterates over `dataloader`, which is not defined yet)
2. Implementation of equation 5 for the class weights `weight_classes_dataset` defined in the **Inputs** section.
3. Implementation of the DistilGPT2 and the BERT modules
4. Evaluation methods (e.g., Quadratic Weighted Kappa (QWK))

In [None]:
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

# Set random seed for reproducibility
manualSeed = 0
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
# numpy is imported by this notebook as well; seed it so any numpy-based sampling is reproducible.
np.random.seed(manualSeed)
torch.manual_seed(manualSeed)  # seeds the RNGs of all devices (CPU and CUDA)
# With deterministic algorithms enabled, cuBLAS matmuls (used by nn.Linear on CUDA)
# raise a RuntimeError unless this workspace setting is present. setdefault keeps any
# value the user already exported before starting the kernel.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
torch.use_deterministic_algorithms(True)  # Needed for reproducible results

## Inputs




In [None]:
# ---- Optimisation ---------------------------------------------------------
num_epochs = 10  # number of training epochs
lr = 2e-4  # learning rate for the Adam optimizers
# (beta1, beta2) for the Adam optimizers
betas_ADAM = (0.9, 0.999)  # Note: no values reported in the paper

# ---- Dataset and losses -----------------------------------------------------
n_class_dataset = 3  # number of classes in the dataset
# Per-class weights for the weighted cross-entropy loss (uniform for now).
# TODO: replace with equation 5 from the paper (see the To Do list).
weight_classes_dataset = torch.ones(n_class_dataset, dtype=torch.float32)
# lambda values weighting the extra terms of the loss functions
lambda_score = 1
lambda_feature_matching = 1

# ---- Architecture ---------------------------------------------------------
d_in = 100  # input dimension of G1, i.e. size of the latent vector z
d_out = 768  # output dimension of G1
p_dropout = 0.5  # dropout probability

# ---- Hardware --------------------------------------------------------------
# Decide which device we want to run on
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Data




## Generator




In [None]:
# Generator G1, see figure 2a
class Generator(nn.Module):
    """Generator G1 (figure 2a): maps latent noise vectors z to fake feature vectors v_G.

    Layers: Linear(d_in -> d_out), LeakyReLU(0.2), Dropout(p_dropout),
    Linear(d_out -> d_out). The sizes come from the module-level
    hyper-parameters ``d_in``, ``d_out`` and ``p_dropout`` at construction time.
    """

    def __init__(self):
        super().__init__()
        hidden_stage = [
            nn.Linear(d_in, d_out),  # input: Z
            nn.LeakyReLU(0.2),
            nn.Dropout(p=p_dropout),
        ]
        output_projection = nn.Linear(d_out, d_out)  # output: v_G
        self.main = nn.Sequential(*hidden_stage, output_projection)

    def forward(self, input):
        """Return v_G for a batch of latent vectors (last dimension must be d_in)."""
        v_g = self.main(input)
        return v_g

In [None]:
# Build the generator G1 and move it to the selected device
# (nn.Module.to works in place and returns the same module).
netG = Generator()
netG.to(device)

# NOTE(review): a custom initialiser call `netG.apply(weights_init)` was left
# commented out here, and `weights_init` is not defined in the visible notebook,
# so the layers keep PyTorch's default initialisation.

print(netG)  # show the layer structure

## Discriminator




In [None]:
# Discriminator D, see figure 2d
class Discriminator(nn.Module):
    """Discriminator D (figure 2d) applied to real (v_B) or generated (v_G) vectors.

    ``forward`` returns a tuple ``(features, logits, probs)``:
      * features -- output of the first stage, used by the feature matching loss;
      * logits   -- 1 + n_class_dataset values laid out as [fake score, dataset classes];
      * probs    -- softmax over the last dimension of ``logits``.
    """

    def __init__(self):
        super().__init__()
        # Stage 1 -- input: v_G or v_B
        self.seq1 = nn.Sequential(
            nn.Dropout(p=p_dropout),
            nn.Linear(d_out, d_out),
        )
        # Stage 2 -- the extra +1 output scores whether the sample is fake/real
        logit_count = 1 + n_class_dataset
        self.seq2 = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.Dropout(p=p_dropout),
            nn.Linear(d_out, logit_count),
        )
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input):
        """Return (features, logits, probs) for a batch of d_out-dimensional vectors."""
        hidden = self.seq1(input)  # required for the feature matching loss
        scores = self.seq2(hidden)
        return hidden, scores, self.softmax(scores)

In [None]:
# Build the discriminator D and move it to the selected device
# (nn.Module.to works in place and returns the same module).
netD = Discriminator()
netD.to(device)

# NOTE(review): a custom initialiser call `netD.apply(weights_init)` was left
# commented out here, and `weights_init` is not defined in the visible notebook,
# so the layers keep PyTorch's default initialisation.

print(netD)  # show the layer structure

## Training




In [None]:
# Real/fake loss, applied to the first discriminator logit (the fake score)
criterionGAN = nn.BCEWithLogitsLoss()
# Class-score loss. `weight_classes_dataset` is created without a device (CPU);
# move it to `device` so the weighted loss also works when the logits are on the
# GPU -- a CPU weight combined with CUDA logits raises a device-mismatch error.
criterionScore = nn.CrossEntropyLoss(weight=weight_classes_dataset.to(device))

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=betas_ADAM)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=betas_ADAM)

# Establish convention for real and fake labels during training.
# The first discriminator logit scores how likely a sample is to be *fake*,
# hence real -> 0 and fake -> 1.
real_label = 0.
fake_label = 1.

In [None]:
# Training Loop
#
# Each iteration performs one discriminator (D) update followed by one generator (G) update.
# NOTE(review): `dataloader` is not defined in this notebook yet (To Do item 1). Each batch
# is expected to be `[v_B, y_labels]`: v_B feeds netD directly, so its last dimension must
# be d_out, and y_labels must be class indices in [0, n_class_dataset) as required by
# nn.CrossEntropyLoss -- confirm once the loader exists.

# Lists to keep track of progress
D_losses = []
G_losses = []
iters = 0

print("Starting Training Loop ...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network
        ############################
        netD.zero_grad()
        # data is a list of [v_B, y_labels]
        v_B = data[0].to(device)  # output of the BERT module for real essays (CLS hidden state)
        y_labels = data[1].to(device)  # true class labels
        b_size = v_B.size(0)
        # Two distinct target tensors: BCEWithLogitsLoss saves its target for backward(),
        # so refilling one shared tensor in place before loss_D_total.backward() would
        # trip autograd's in-place modification check for loss_D_real.
        label_real = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        label_fake = torch.full((b_size,), fake_label, dtype=torch.float, device=device)

        ## Train with all-real batch
        _, logits, _ = netD(v_B)
        loss_D_real = criterionGAN(logits[:, 0], label_real)  # column 0: fake score
        loss_D_score = criterionScore(logits[:, 1:], y_labels)  # remaining columns: class logits

        ## Train with all-fake batch
        # Generate a batch of latent vectors; G1 starts with Linear(d_in, ...), so z is (b_size, d_in)
        input_noise = torch.randn(b_size, d_in, device=device)
        v_G = netG(input_noise)
        # detach(): the D update must not back-propagate into G
        _, logits, _ = netD(v_G.detach())
        loss_D_fake = criterionGAN(logits[:, 0], label_fake)
        loss_D_real_and_fake = loss_D_real + loss_D_fake

        loss_D_total = loss_D_real_and_fake + lambda_score * loss_D_score
        loss_D_total.backward()  # Calculate gradients for D
        optimizerD.step()  # Update D

        ############################
        # (2) Update G network
        ############################
        netG.zero_grad()
        # Since we just updated D, perform another forward pass
        features_real, _, _ = netD(v_B)
        features_fake, logits, _ = netD(v_G)
        # G wants its samples to be judged real, so its target is the real label
        loss_G_caught = criterionGAN(logits[:, 0], label_real)
        # Feature matching: squared difference of the batch-mean D features on real vs fake data
        loss_G_feature_matching = torch.mean(
            torch.square(torch.mean(features_real, dim=0) - torch.mean(features_fake, dim=0))
        )

        loss_G_total = loss_G_caught + lambda_feature_matching * loss_G_feature_matching
        # This backward also deposits gradients on netD's parameters; only optimizerG steps
        # here, and netD.zero_grad() clears them at the start of the next iteration.
        loss_G_total.backward()  # Calculate gradients for G
        optimizerG.step()  # Update G

        ############################
        # Training statistics
        ############################
        if i % 10 == 0:
            print('[%3d/%3d][%3d/%3d]\tLoss_D: %.4f\tLoss_G: %.4f'
                  % (epoch, num_epochs, i, len(dataloader), loss_D_total.item(), loss_G_total.item()))

        # Save Losses for plotting later
        D_losses.append(loss_D_total.item())
        G_losses.append(loss_G_total.item())

        iters += 1
print("--------------------------------------------------------------------------------------------")

## Training Loss


In [None]:
# Plot the per-iteration losses recorded by the training loop.
loss_curves = (("G", G_losses), ("D", D_losses))

plt.figure(figsize=(6, 4))
plt.title("Generator and Discriminator Loss During Training")
for curve_label, curve in loss_curves:
    plt.plot(curve, label=curve_label)
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()