In [1]:
import torch
import pandas as pd
from transformers import Trainer, TrainingArguments
from torch.optim import AdamW
from transformers import get_scheduler
import torch.nn as nn
from src.load_dataloader import initial_dataloader_vector_slicing
from src.evaluation import initial_LSTM
from src.load_config import load_config
from src.evaluation import show_sentence
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
# import tqdm
from tqdm import tqdm

batch_size = 8
max_length = 128
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"


In [2]:
model, tokenizer, device = load_config(max_length)
# import data
df = pd.read_csv('data_preprocess/datasets_combine.csv')
train_df, val_df,test_df, train_dataset,val_dataset, test_dataset,train_loader, val_loader, test_loader = initial_dataloader_vector_slicing(df, tokenizer, max_length,batch_size)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
LSTM_model = initial_LSTM(tokenizer, device)

In [4]:
def classify_sentence(predicted_token_ids, lstm_model):
    # convert predicted_token_ids to text first
    texts1 = tokenizer.batch_decode(predicted_token_ids, skip_special_tokens=True)
    # convert back to token ids
    predicted_token_ids = tokenizer(texts1, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')['input_ids'].to(device)
    
    #predicted_token_ids = tokenizer(texts1, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')['input_ids'].to(device)

    # Initialize a tensor to store LSTM outputs
    lstm_outputs = torch.empty(size=(predicted_token_ids.shape[0], 2)).to(device)  # 2 for binary classification

    # Process each item in the batch
    for idx, token_ids in enumerate(predicted_token_ids):
        lstm_model.eval()
        with torch.no_grad():
            # Assuming lstm_model.predict returns a tensor of shape [1, 2] (probability for each class)
            lstm_output = lstm_model(token_ids.flatten())  # Modify this call according to your LSTM model's interface
            lstm_outputs[idx] = lstm_output

    return lstm_outputs

In [5]:
optimizer = AdamW(model.parameters(), lr=1e-5)
num_epochs = 10
model.to(device)
loss_classification = nn.CrossEntropyLoss()

for epoch in tqdm(range(num_epochs)):
    model.train()
    
    for batch in train_loader:
        input_ids1 = batch['input_ids1'].to(device)
        attention_mask1 = batch['attention_mask1'].to(device)
        labels1 = batch['labels1'].to(device)
        
        input_ids2 = batch['input_ids2'].to(device)
        attention_mask2 = batch['attention_mask2'].to(device)
        labels2 = batch['labels2'].to(device)
        
        # first reconstruct the two sentences
        
        outputs1 = model(input_ids=input_ids1, attention_mask=attention_mask1, labels=labels1)
        outputs2 = model(input_ids=input_ids2, attention_mask=attention_mask2, labels=labels2)
        
        loss = outputs1.loss + outputs2.loss
        
        # ----------------------- perform style transfer -----------------------
        # call model encoder
        encoder_outputs1 = model.encoder(input_ids=input_ids1)
        latent_vector1 = encoder_outputs1.last_hidden_state
        
        encoder_outputs2 = model.encoder(input_ids=input_ids2)
        latent_vector2 = encoder_outputs2.last_hidden_state

        # modify latent vector
        latent_vector1_content = latent_vector1[:, :, :384]
        latent_vector1_style = latent_vector1[:, :, 384:]
        
        latent_vector2_content = latent_vector2[:, :, :384]
        latent_vector2_style = latent_vector2[:, :, 384:]
        
        # swap style
        modify_latent_vector1 = torch.cat([latent_vector1_content, latent_vector2_style], dim=-1)
        modify_latent_vector2 = torch.cat([latent_vector2_content, latent_vector1_style], dim=-1)
        
        encoder_outputs1.last_hidden_state = modify_latent_vector1
        encoder_outputs2.last_hidden_state = modify_latent_vector2
        
        outputs1 = model(decoder_input_ids = input_ids1, encoder_outputs=encoder_outputs1)
        outputs2 = model(decoder_input_ids = input_ids2, encoder_outputs=encoder_outputs2)

        logits1 = outputs1.logits
        predicted_token_ids1 = torch.argmax(logits1, dim=-1).to(device)
        transfer_labels1 = batch['sentence2_style'].to(device)
        
        logits2 = outputs2.logits
        predicted_token_ids2 = torch.argmax(logits2, dim=-1).to(device)
        transfer_labels2 = batch['sentence1_style'].to(device)
        
        # sent to LSTM model
        lstm_outputs1 = classify_sentence(predicted_token_ids1,LSTM_model)
        lstm_outputs2 = classify_sentence(predicted_token_ids2, LSTM_model)  
        
        # classification loss
        style_loss1 = loss_classification(lstm_outputs1,transfer_labels1)
        style_loss2 = loss_classification(lstm_outputs2,transfer_labels2)
        loss += style_loss1 + style_loss2
        
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    

  'input_ids1': torch.tensor(item['input_ids1'], dtype=torch.long).squeeze(0),
  'input_ids2': torch.tensor(item['input_ids2'], dtype=torch.long).squeeze(0),
  'attention_mask1': torch.tensor(item['attention_mask1'], dtype=torch.long).squeeze(0),
  'attention_mask2': torch.tensor(item['attention_mask2'], dtype=torch.long).squeeze(0),
  'labels1': torch.tensor(item['labels1'], dtype=torch.long).squeeze(0),
  'labels2': torch.tensor(item['labels2'], dtype=torch.long).squeeze(0),
 10%|█         | 1/10 [10:45<1:36:52, 645.88s/it]

In [None]:
for index in range(20):
    
    text_pair = test_df.iloc[index]
    text1 = text_pair['sentence']
    text2 = text_pair['target_text']

    text_id1 = tokenizer(text1, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')['input_ids'].to(device)
    text_id2 = tokenizer(text2, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')['input_ids'].to(device)

    encoder_outputs1 = model.encoder(input_ids=text_id1)
    latent_vector1 = encoder_outputs1.last_hidden_state

    encoder_outputs2 = model.encoder(input_ids=text_id2)
    latent_vector2 = encoder_outputs2.last_hidden_state

    # modify latent vector
    latent_vector1_content = latent_vector1[:, :, :384]
    latent_vector1_style = latent_vector1[:, :, 384:]

    latent_vector2_content = latent_vector2[:, :, :384]
    latent_vector2_style = latent_vector2[:, :, 384:]

    # swap style
    modify_latent_vector1 = torch.cat([latent_vector1_content, latent_vector2_style], dim=-1)
    modify_latent_vector2 = torch.cat([latent_vector2_content, latent_vector1_style], dim=-1)

    outputs1 = model(decoder_input_ids = text_id1, encoder_outputs=encoder_outputs1)
    outputs2 = model(decoder_input_ids = text_id2, encoder_outputs=encoder_outputs2)

    logits1 = outputs1.logits
    predicted_token_ids1 = torch.argmax(logits1, dim=-1).to(device).flatten()

    logits2 = outputs2.logits
    predicted_token_ids2 = torch.argmax(logits2, dim=-1).to(device).flatten()

    new_text1 = tokenizer.decode(predicted_token_ids1, skip_special_tokens=True)
    new_text2 = tokenizer.decode(predicted_token_ids2, skip_special_tokens=True)

    print("original text1: ", text1)
    print("transfer text1: ", new_text1)
    print("original text2: ", text2)
    print("transfer text2: ", new_text2)
    print("--------------------------------------------------")

original text1:  I'll not call you tyrant;But this most cruel usage of your queen,Not able to produce more accusationThan your own weak-hinged fancy, something savoursOf tyranny and will ignoble make you,Yea, scandalous to the world
transfer text1:  I not call you ty this most  usage of your queen, able to produce more accusorThan your own -hinged fancy, something tavoursOf tyranny and will ignoble make you,Jea  scandalous to the world.
original text2:  Investigation focused on pilots reactions to instrument failure, autopilot switching off
transfer text2:  focused on ons reactions to instrument failure, autopilot switching off.Wssss
--------------------------------------------------
original text1:  Jordan Silverstone bought the suit for just £125 from charity shop
transfer text1:  Jordanstone from for for from just $225 from the shopssscomings offer just
original text2:  First Herald:Harry of Hereford, Lancaster and Derby,Stands here for God, his sovereign and himself,On pain to be f

In [None]:
# save model
model.save_pretrained('./T5_model_sliding')
#torch.save(model.state_dict(), './model_save/T5_model.pth')

In [None]:
# READ model
# model = T5ForConditionalGeneration.from_pretrained('./model_save')