#### This notebook fine-tunes a causal language model using a GAN setup

### Setup

In [None]:
!pip install datasets
!pip install transformers[sentencepiece]
!pip install --upgrade transformers
!pip install einops
!pip install torch
!pip install huggingface_hub
!pip install sentencepiece

In [22]:
from datasets import load_dataset, load_from_disk, concatenate_datasets, Dataset, DatasetDict

import pandas as pd
import numpy as np

from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, AdamW
from huggingface_hub import login

import torch
import torch.nn as nn
import torch.optim as optim

import os
import re
import time


# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [7]:
# Load the DatasetDict stored on disk in the relative directory "dataset".
# NOTE(review): sampleDataset() below reads dataset['train'] and expects 'text' and
# 'label' columns with labels 0/1 — confirm this matches what was saved there.
dataset = DatasetDict.load_from_disk("dataset")

In [13]:
# Returns an equal number of positive and negative samples, e.g. samples=1000 gives
# 500 positive and 500 negative datapoints.
def sampleDataset(samples, seed=None):
    """Draw a shuffled, label-balanced sample from ``dataset['train']``.

    Args:
        samples: Total number of examples wanted. ``samples // 2`` are taken from
            each label (0 and 1), so an odd value yields ``samples - 1`` rows.
        seed: Optional seed passed to every ``shuffle`` call so the sample is
            reproducible. ``None`` (the default) keeps the unseeded behaviour.

    Returns:
        A dataset with ``2 * (samples // 2)`` rows in random order.

    Raises:
        ValueError: if either label has fewer than ``samples // 2`` examples.
    """
    train_dataset = dataset['train']
    per_label = samples // 2

    dataset_sentiment_0 = train_dataset.filter(lambda x: x['label'] == 0)
    dataset_sentiment_1 = train_dataset.filter(lambda x: x['label'] == 1)

    # Fail with a clear message instead of an opaque out-of-range error from select().
    for label, subset in ((0, dataset_sentiment_0), (1, dataset_sentiment_1)):
        if len(subset) < per_label:
            raise ValueError(
                f"Requested {per_label} examples with label {label}, "
                f"but only {len(subset)} are available."
            )

    dataset_sampled_0 = dataset_sentiment_0.shuffle(seed=seed).select(range(per_label))
    dataset_sampled_1 = dataset_sentiment_1.shuffle(seed=seed).select(range(per_label))

    sampled_dataset = concatenate_datasets([dataset_sampled_0, dataset_sampled_1]).shuffle(seed=seed)

    # Count labels from the single 'label' column rather than iterating every full row twice.
    labels = list(sampled_dataset['label'])
    print("Positive: ", labels.count(1))
    print("Negative: ", labels.count(0))
    print(sampled_dataset)

    return sampled_dataset


In [11]:
# Draw a balanced training sample of 1000 datapoints (500 per label) for the GAN.
sampled_dataset = sampleDataset(samples=1000)

Filter:   0%|          | 0/1600000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1600000 [00:00<?, ? examples/s]

Positive:  500
Negative:  500
Dataset({
    features: ['text', 'label'],
    num_rows: 1000
})


In [9]:
# Prompt for positive tweets. It ends with the opening quote after 'Positive: ', so
# generation continues directly into the tweet text; findTweets() later searches the
# decoded output for this 'Positive: "..."' marker to extract the tweet.
positivePrompt = '''
Generate a positive social media tweet on a specific topic. Ensure your tweet is enclosed in straight double quotation marks. Provide ONLY one tweet. The positive tweet should express enthusiasm or praise. 

Positive: "'''

In [10]:
# Prompt for negative tweets; mirrors positivePrompt. It ends with the opening quote
# after 'Negative: ', so generation continues directly into the tweet text; findTweets()
# later searches the decoded output for this 'Negative: "..."' marker.
# (Fixed the "express convey" typo and dropped the stray "separated by a colon"
# instruction, which the positive prompt does not have.)
negativePrompt = '''
Generate a negative social media tweet on a specific topic. Ensure your tweet is enclosed in straight double quotation marks. Provide ONLY one tweet. The negative tweet should express criticism or disappointment. 

Negative: "'''

In [14]:
# Load the generator model and its tokenizer from the Hugging Face Hub.
# Authenticate only when a token is supplied via the HF_TOKEN environment variable,
# so no credential is hard-coded in the notebook.
# NOTE(review): login is only required for gated/private models — confirm whether the
# chosen model_name needs it.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)

model_name = "microsoft/Phi-3-mini-4k-instruct"
# model_name = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

torch.cuda.empty_cache()

# Move to the device selected in the imports cell (a no-op on CPU-only machines).
model = model.to(device)

# Reuse EOS as the padding token if the tokenizer has none. Adding a brand-new
# '[PAD]' special token would grow the tokenizer vocabulary without resizing the
# model's embeddings, so that id would be out of range for the model.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

tokenizer_config.json:   0%|          | 0.00/3.44k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.94M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/306 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/599 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/967 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/16.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/2.67G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/181 [00:00<?, ?B/s]

### Model Prompting

A major issue we ran into here was the model repeating the instructions back in its output. The model also tends to include additional unwanted information after the requested tweet. This isn't ideal, because that extra text shouldn't be part of what the generator and discriminator learn from. 

The problem with simply post-processing the string output is that it doesn't give us the matching logits (the logits are what the GAN setup actually uses). To solve this, we created a three-step approach: 

1. Call promptModel() with an initial `max_length` of 100 tokens (this limit includes the prompt tokens).
2. Extract the tweet from the unwanted content using findTweets().
3. Feed the tweet back into the model using promptModel(), with `max_length` set to the tweet's token length plus one.

Now the model is forced to return essentially just the tweet (there is no room for additional content), and we still get the correct logits. 

In [15]:
# Returns the text and logits of a generated tweet.
def promptModel(prompt, length):
    """Sample one continuation of ``prompt`` and return its decoded text and logits.

    Args:
        prompt: Text tokenized and used as the generation prompt.
        length: Passed to ``model.generate`` as ``max_length``. This caps the TOTAL
            sequence length (prompt tokens + generated tokens), not just new tokens.

    Returns:
        ``(text, logits)``. ``text`` is ``outputs.sequences[0]`` decoded with special
        tokens skipped (for a decoder-only model this includes the prompt).
        ``logits`` come from a forward pass of ``model`` over that full sequence,
        computed under ``torch.no_grad()`` — they carry no gradient back to the model.
    """
    inputs = tokenizer(prompt, return_tensors='pt').to(device)

    # `pad_token_id or eos_token_id` would wrongly fall back to EOS when the pad id
    # is 0 (falsy), so test for None explicitly.
    if tokenizer.pad_token_id is not None:
        pad_token_id = tokenizer.pad_token_id
    else:
        pad_token_id = tokenizer.eos_token_id

    outputs = model.generate(
        inputs['input_ids'],
        attention_mask=inputs['attention_mask'],  # already on `device` via inputs.to()
        do_sample=True,
        num_return_sequences=1,
        max_length=length,
        temperature=0.9,
        top_p=0.9,
        repetition_penalty=1.2,
        return_dict_in_generate=True,  # needed to read outputs.sequences
        pad_token_id=pad_token_id,
    )

    # Decode the generated text
    text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)

    with torch.no_grad():
        # Forward pass over the generated sequence to get its logits
        logits = model(outputs.sequences).logits

    return text, logits

In [None]:
# Uses a regex to pull the actual tweet content out of the generated text.
def findTweets(generated_text, isPositive):
    """Extract the quoted tweet that follows the ``Positive:``/``Negative:`` tag.

    Curly double quotes are first normalised to straight ones. ``isPositive``
    selects which tag to look for. Returns the stripped tweet text, or the
    sentinel string ``"-1"`` when no quoted tweet follows that tag.
    """
    normalized = generated_text.translate(str.maketrans({'“': '"', '”': '"'}))

    pattern = (
        r'Positive:\s*"\s*([^"]*)\s*"'
        if isPositive
        else r'Negative:\s*"\s*([^"]*)\s*"'
    )
    found = re.search(pattern, normalized, re.DOTALL)

    return "-1" if found is None else found.group(1).strip()

In [16]:
# Return a tweet and its logits using the two-pass filter described above.
def generateAndExtractTweets(prompt, label, length, max_attempts=50):
    """Generate a tweet for ``prompt`` and return it re-encoded together with its logits.

    First pass: sample with ``max_length=length`` and pull the quoted tweet out of
    the output with ``findTweets``. Second pass: feed only the tweet back through
    ``promptModel`` with ``max_length`` = tweet token count + 1, so the result is
    the tweet plus at most one newly generated token, together with its logits.

    Args:
        prompt: Generation prompt (e.g. ``positivePrompt`` / ``negativePrompt``).
        label: Sentiment label; ``label == 1`` selects the ``Positive:`` tag.
        length: ``max_length`` for the first pass (includes the prompt tokens).
        max_attempts: Maximum number of first-pass generations before giving up.
            ``None`` retries forever (the previous behaviour).

    Returns:
        ``(text, logits)`` from the second ``promptModel`` call.

    Raises:
        RuntimeError: if no tweet could be extracted within ``max_attempts`` tries.
    """
    attempt = 0
    while max_attempts is None or attempt < max_attempts:
        attempt += 1
        text, _ = promptModel(prompt, length)
        tweet = findTweets(text, label == 1)

        # findTweets signals "no match" with the sentinel string "-1".
        if tweet != "-1" and len(tweet) > 0:
            tweet_length = len(tokenizer.encode(tweet, add_special_tokens=True))
            # promptModel also tokenizes with special tokens, so this cap leaves room
            # for at most one newly generated token after the tweet.
            return promptModel(tweet, tweet_length + 1)

    raise RuntimeError(
        f"No tweet could be extracted from the model output after {max_attempts} attempts."
    )

### GAN 

Our GAN uses 1000 datapoints. You can change this amount by editing the `sampleDataset(1000)` call in the Setup section. 

We train for 5 epochs. You can change this by editing the `epochs` variable in the training loop.

In [19]:
# Our Discriminator class
class Discriminator(nn.Module):
    """LSTM classifier that scores token-id sequences as real (1) vs. generated (0).

    Args:
        vocab_size: Rows in the embedding table; every input id must be < vocab_size.
        embed_size: Embedding dimension fed to the LSTM.
        hidden_size: LSTM hidden size (defaults to 128, the previously hard-coded value).
    """

    def __init__(self, vocab_size, embed_size, hidden_size=128):
        super(Discriminator, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids):
        """Score ``input_ids`` laid out as (batch, seq_len) (the LSTM is batch_first).

        Returns a (batch, 1) tensor of probabilities in (0, 1). The output has already
        been passed through a sigmoid, so pair it with ``nn.BCELoss``, not
        ``nn.BCEWithLogitsLoss``.
        """
        embeds = self.embedding(input_ids)   # (batch, seq_len, embed_size)
        _, (hidden, _) = self.lstm(embeds)   # hidden: (num_layers, batch, hidden_size)
        output = self.fc(hidden[-1])         # last layer's final hidden state -> (batch, 1)
        return self.sigmoid(output)


In [None]:
# Instantiate the generator (the causal LM loaded above)
generator = model

# Instantiate the Discriminator. Its embedding table must cover every id it can
# receive: real text is tokenized with `tokenizer`, while fake text is an argmax over
# the generator's output vocabulary, which can be larger than len(tokenizer)
# (e.g. a padded vocab). Sizing to the max of both avoids out-of-range embedding lookups.
disc_vocab_size = max(len(tokenizer), generator.config.vocab_size)
discriminator = Discriminator(disc_vocab_size, 768).to(device)

# Optimizers. torch.optim.AdamW replaces the deprecated transformers.AdamW; eps and
# weight_decay are set to that optimizer's former defaults to keep the same behaviour.
optimizerG = optim.AdamW(generator.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)
optimizerD = optim.Adam(discriminator.parameters(), lr=5e-5)

# Loss function. Discriminator.forward already applies a sigmoid, so plain BCELoss is
# correct here; BCEWithLogitsLoss would apply a second sigmoid on top of it.
criterion = nn.BCELoss()

# Dataloader: one example per batch
data_loader = torch.utils.data.DataLoader(sampled_dataset, batch_size=1, shuffle=True)

In [None]:
# For more accurate error messages.
# NOTE(review): these flags only take effect if set before CUDA is initialised; by this
# point the model is already on the GPU, so set them at the top of the notebook (or in
# the launching environment) if you actually need them.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ["PYTORCH_USE_CUDA_DSA"] = "1"

batchCount = 1
epochs = 5
for epoch in range(epochs):
    for batch in data_loader:
        # batch_size=1: presumably batch['text'] is a one-string list and batch['label']
        # a one-element tensor (default DataLoader collation).

        # Tokenize the real text into (batch, seq_len) ids on the target device.
        padded_inputs = tokenizer(batch['text'], padding=True, return_tensors="pt", truncation=True)
        real_text = padded_inputs['input_ids'].to(device)

        # Generate FAKE outputs with the prompt that matches the real example's label.
        labelUsed = batch["label"]
        promptUsed = positivePrompt if labelUsed == 1 else negativePrompt
        fake_text, fake_logits_raw = generateAndExtractTweets(promptUsed, labelUsed, 100)

        # Greedy token ids from the logits, kept as (1, seq_len) so the discriminator
        # sees ONE sequence. Reshaping to (-1, 1) would turn every token into its own
        # length-1 sequence.
        fake_tokens = fake_logits_raw.argmax(dim=-1).detach().to(device)

        # ---- Discriminator step ----
        discriminator.zero_grad()

        real_output = discriminator(real_text)
        real_labels = torch.ones((real_text.size(0), 1), dtype=torch.float, device=device)
        lossD_real = criterion(real_output.view(-1, 1), real_labels)

        fake_output = discriminator(fake_tokens)
        fake_labels = torch.zeros((fake_output.size(0), 1), dtype=torch.float, device=device)
        lossD_fake = criterion(fake_output.view(-1, 1), fake_labels)

        # Combine real and fake losses for the discriminator
        lossD = lossD_real + lossD_fake
        lossD.backward()

        torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
        optimizerD.step()

        # ---- Generator step ----
        # NOTE(review): the generator receives NO gradient from this step. promptModel
        # computes logits under torch.no_grad() and argmax is non-differentiable, so
        # lossG back-propagates only into the discriminator (its grads are cleared by
        # discriminator.zero_grad() next iteration) and optimizerG.step() leaves the
        # generator unchanged. Actually training the generator needs a policy-gradient
        # (e.g. REINFORCE on sampled tokens) or a differentiable relaxation (Gumbel-softmax).
        generator.zero_grad()

        # Generate new fake text logits
        with torch.no_grad():
            fake_text, fake_logits_raw = generateAndExtractTweets(promptUsed, labelUsed, 100)
        fake_tokens = fake_logits_raw.argmax(dim=-1).detach().to(device)

        # Get the discriminator's assessment of the newly generated fake data; the
        # generator "wins" when it is scored as real.
        fake_output = discriminator(fake_tokens)
        real_labels = torch.ones((fake_output.size(0), 1), dtype=torch.float, device=device)

        lossG = criterion(fake_output.view(-1, 1), real_labels)
        lossG.backward()

        torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
        optimizerG.step()

        # Print status and release per-batch tensors (lossD/lossG are kept for the save cell).
        print(f'Epoch [{epoch}/{epochs}] Loss_D: {lossD.item():.4f} Loss_G: {lossG.item():.4f}')
        print("BATCH NUMBER " + str(batchCount))
        batchCount += 1
        del padded_inputs, real_text, fake_tokens, fake_logits_raw, real_labels, real_output, fake_labels, fake_output
        torch.cuda.empty_cache()

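Save the GAN. 

Before running the cell below, set `YOUR_PATH` to the output directory and `YOUR_EPOCHS` to the number of epochs trained.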


In [31]:
# Output directory and the epoch count to record; replace both placeholders first.
save_dir = YOUR_PATH
epoch = YOUR_EPOCHS

# exist_ok=True avoids the check-then-create race of os.path.exists() + os.makedirs().
os.makedirs(save_dir, exist_ok=True)

model_paths = {
    key: os.path.join(save_dir, filename)
    for key, filename in {
        'generator': 'generator.pth',
        'discriminator': 'discriminator.pth',
        'optimizerG': 'optimizerG.pth',
        'optimizerD': 'optimizerD.pth',
        'checkpoint': 'checkpoint.pth',
    }.items()
}

# Save the generator and discriminator state
torch.save(generator.state_dict(), model_paths['generator'])
torch.save(discriminator.state_dict(), model_paths['discriminator'])

# Save the optimizer states
torch.save(optimizerG.state_dict(), model_paths['optimizerG'])
torch.save(optimizerD.state_dict(), model_paths['optimizerD'])

# Save the epoch number and the most recent losses. detach().cpu() drops the autograd
# graph and GPU placement so the checkpoint loads on any machine; the values remain
# scalar tensors, as before.
checkpoint = {
    'epoch': epoch,
    'lossG': lossG.detach().cpu(),  # Last generator loss
    'lossD': lossD.detach().cpu(),  # Last discriminator loss
}
torch.save(checkpoint, model_paths['checkpoint'])
