In [1]:
import torch
import pandas as pd
from transformers import Trainer, TrainingArguments
from torch.optim import AdamW
from transformers import get_scheduler
import torch.nn as nn
from src.load_dataloader import initial_dataloader_vector_slicing
from src.evaluation import initial_LSTM
from src.load_config import load_config
from src.evaluation import show_sentence
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
# import tqdm
from tqdm import tqdm

batch_size = 8
max_length = 128
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"


In [2]:
model, tokenizer, device = load_config(max_length)
# import data
df = pd.read_csv('data_preprocess/datasets_combine.csv')
# df = df[:1000]
train_df, val_df,test_df, train_dataset,val_dataset, test_dataset,train_loader, val_loader, test_loader = initial_dataloader_vector_slicing(df, tokenizer, max_length,batch_size)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
LSTM_model = initial_LSTM(tokenizer, device)

In [4]:
def custom_loss_function(texts1, transfer_labels, lstm_model):
    # Assuming transfer_labels is a tensor of shape [batch_size]
    # and predicted_token_ids is of shape [batch_size, seq_len]
    predicted_token_ids = tokenizer(texts1, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')['input_ids'].to(device)
    
    # Initialize CrossEntropyLoss
    criterion = nn.CrossEntropyLoss()

    # Initialize a tensor to store LSTM outputs
    lstm_outputs = torch.empty(size=(predicted_token_ids.shape[0], 2)).to(device)  # 2 for binary classification

    # Process each item in the batch
    for idx, token_ids in enumerate(predicted_token_ids):
        lstm_model.eval()
        with torch.no_grad():
            # Assuming lstm_model.predict returns a tensor of shape [1, 2] (probability for each class)
            lstm_output = lstm_model(token_ids.flatten())  # Modify this call according to your LSTM model's interface
            lstm_outputs[idx] = lstm_output

    # Compute the loss
    loss = criterion(lstm_outputs, transfer_labels)

    return loss, lstm_outputs, transfer_labels

In [5]:
optimizer = AdamW(model.parameters(), lr=1e-5)
num_epochs = 10
model.to(device)

for epoch in tqdm(range(num_epochs)):
    model.train()
    
    for batch in train_loader:
        input_ids1 = batch['input_ids1'].to(device)
        attention_mask1 = batch['attention_mask1'].to(device)
        labels1 = batch['labels1'].to(device)
        
        input_ids2 = batch['input_ids2'].to(device)
        attention_mask2 = batch['attention_mask2'].to(device)
        labels2 = batch['labels2'].to(device)
        
        # first reconstruct the two sentences
        
        outputs1 = model(input_ids=input_ids1, attention_mask=attention_mask1, labels=labels1)
        outputs2 = model(input_ids=input_ids2, attention_mask=attention_mask2, labels=labels2)
        
        loss = outputs1.loss + outputs2.loss
        
        # ----------------------- perform style transfer -----------------------
        # call model encoder
        encoder_outputs1 = model.encoder(input_ids=input_ids1)
        latent_vector1 = encoder_outputs1.last_hidden_state
        
        encoder_outputs2 = model.encoder(input_ids=input_ids2)
        latent_vector2 = encoder_outputs2.last_hidden_state

        # modify latent vector
        latent_vector1_content = latent_vector1[:, :, :384]
        latent_vector1_style = latent_vector1[:, :, 384:]
        
        latent_vector2_content = latent_vector2[:, :, :384]
        latent_vector2_style = latent_vector2[:, :, 384:]
        
        # swap style
        modify_latent_vector1 = torch.cat([latent_vector1_content, latent_vector2_style], dim=-1)
        modify_latent_vector2 = torch.cat([latent_vector2_content, latent_vector1_style], dim=-1)
        
        encoder_outputs1.last_hidden_state = modify_latent_vector1
        encoder_outputs2.last_hidden_state = modify_latent_vector2
        
        decoder_input_text = "style transfer:"
        decoder_input_id = tokenizer(decoder_input_text, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')['input_ids'].to(device)
        # sahpe of decoder_input_id should be [batch_size, seq_max_len]
        decoder_input_ids = decoder_input_id.repeat(input_ids1.shape[0], 1)
        
        outputs1 = model(decoder_input_ids = decoder_input_ids, encoder_outputs=encoder_outputs1)
        outputs2 = model(decoder_input_ids = decoder_input_ids, encoder_outputs=encoder_outputs2)

        logits1 = outputs1.logits
        predicted_token_ids1 = torch.argmax(logits1, dim=-1).to(device)
        transfer_labels1 = batch['sentence2_style'].to(device)
        
        logits2 = outputs2.logits
        predicted_token_ids2 = torch.argmax(logits2, dim=-1).to(device)
        transfer_labels2 = batch['sentence1_style'].to(device)
        
        # sent to LSTM model

        texts1 = tokenizer.batch_decode(predicted_token_ids1, skip_special_tokens=True)
        loss1,_,_ = custom_loss_function(texts1, transfer_labels1, LSTM_model)
            
        texts2 = tokenizer.batch_decode(predicted_token_ids2, skip_special_tokens=True)
        loss2,_,_ = custom_loss_function(texts2, transfer_labels2, LSTM_model)  

        # Accumulate the losses
        #print("-----------------------")
        #print(loss.item())
        loss += loss1.detach() + loss2.detach()
        #print(loss.item())
        
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    
    # evaluate
    model.eval()
    total_eval_accuracy = 0
    total_eval_loss = 0
    nb_eval_steps = 0
    y_true = []
    y_pred = []
    for batch in val_loader:
        input_ids1 = batch['input_ids1'].to(device)
        attention_mask1 = batch['attention_mask1'].to(device)
        labels1 = batch['labels1'].to(device)
        
        input_ids2 = batch['input_ids2'].to(device)
        attention_mask2 = batch['attention_mask2'].to(device)
        labels2 = batch['labels2'].to(device)
        
        with torch.no_grad():
            outputs1 = model(input_ids=input_ids1, attention_mask=attention_mask1, labels=labels1)
            outputs2 = model(input_ids=input_ids2, attention_mask=attention_mask2, labels=labels2)
            
            loss = outputs1.loss + outputs2.loss
            
            logits1 = outputs1.logits
            logits2 = outputs2.logits
            
            predicted_token_ids1 = torch.argmax(logits1, dim=-1).to(device)
            predicted_token_ids2 = torch.argmax(logits2, dim=-1).to(device)
            
            transfer_labels1 = batch['sentence2_style'].to(device)
            transfer_labels2 = batch['sentence1_style'].to(device)
            
            texts1 = tokenizer.batch_decode(predicted_token_ids1, skip_special_tokens=True)
            texts2 = tokenizer.batch_decode(predicted_token_ids2, skip_special_tokens=True)
            
            loss1, lstm_outputs1, transfer_labels1 = custom_loss_function(texts1, transfer_labels1, LSTM_model)
            loss2, lstm_outputs2, transfer_labels2 = custom_loss_function(texts2, transfer_labels2, LSTM_model)
            
            total_lstm_loss = loss1 + loss2
            
            loss += total_lstm_loss
            
            total_eval_loss += loss.item()
            
            # get accuracy by using argmax 
            predicted_labels1 = torch.argmax(lstm_outputs1, dim=-1)
            predicted_labels2 = torch.argmax(lstm_outputs2, dim=-1)
            
            total_eval_accuracy += (predicted_labels1 == transfer_labels1).sum().item()
            total_eval_accuracy += (predicted_labels2 == transfer_labels2).sum().item()
            
            nb_eval_steps += 1
    print("Validation Accuracy: {0:.4f}".format(total_eval_accuracy/nb_eval_steps))

    

  'input_ids1': torch.tensor(item['input_ids1'], dtype=torch.long).squeeze(0),
  'input_ids2': torch.tensor(item['input_ids2'], dtype=torch.long).squeeze(0),
  'attention_mask1': torch.tensor(item['attention_mask1'], dtype=torch.long).squeeze(0),
  'attention_mask2': torch.tensor(item['attention_mask2'], dtype=torch.long).squeeze(0),
  'labels1': torch.tensor(item['labels1'], dtype=torch.long).squeeze(0),
  'labels2': torch.tensor(item['labels2'], dtype=torch.long).squeeze(0),
 10%|█         | 1/10 [1:08:42<10:18:26, 4122.96s/it]

Validation Accuracy: 0.0758


 20%|██        | 2/10 [1:32:48<5:39:44, 2548.07s/it] 

Validation Accuracy: 0.0657


 30%|███       | 3/10 [1:53:01<3:46:07, 1938.26s/it]

Validation Accuracy: 0.0505


 40%|████      | 4/10 [2:13:16<2:45:17, 1652.84s/it]

Validation Accuracy: 0.0354


 50%|█████     | 5/10 [2:26:53<1:52:36, 1351.32s/it]

Validation Accuracy: 0.0253


 60%|██████    | 6/10 [2:40:27<1:17:54, 1168.65s/it]

Validation Accuracy: 0.0253


 70%|███████   | 7/10 [2:53:57<52:34, 1051.53s/it]  

Validation Accuracy: 0.0253


 80%|████████  | 8/10 [3:07:27<32:28, 974.48s/it] 

Validation Accuracy: 0.0253


 90%|█████████ | 9/10 [3:20:55<15:22, 922.53s/it]

Validation Accuracy: 0.0253


100%|██████████| 10/10 [3:34:25<00:00, 1286.56s/it]

Validation Accuracy: 0.0253





In [7]:
for index in range(20):
    
    text_pair = test_df.iloc[index]
    text1 = text_pair['sentence']
    text2 = text_pair['target_text']

    text_id1 = tokenizer(text1, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')['input_ids'].to(device)
    text_id2 = tokenizer(text2, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')['input_ids'].to(device)

    encoder_outputs1 = model.encoder(input_ids=text_id1)
    latent_vector1 = encoder_outputs1.last_hidden_state

    encoder_outputs2 = model.encoder(input_ids=text_id2)
    latent_vector2 = encoder_outputs2.last_hidden_state

    # modify latent vector
    latent_vector1_content = latent_vector1[:, :, :384]
    latent_vector1_style = latent_vector1[:, :, 384:]

    latent_vector2_content = latent_vector2[:, :, :384]
    latent_vector2_style = latent_vector2[:, :, 384:]

    # swap style
    modify_latent_vector1 = torch.cat([latent_vector1_content, latent_vector2_style], dim=-1)
    modify_latent_vector2 = torch.cat([latent_vector2_content, latent_vector1_style], dim=-1)

    outputs1 = model(decoder_input_ids = text_id1, encoder_outputs=encoder_outputs1)
    outputs2 = model(decoder_input_ids = text_id2, encoder_outputs=encoder_outputs2)

    logits1 = outputs1.logits
    predicted_token_ids1 = torch.argmax(logits1, dim=-1).to(device).flatten()

    logits2 = outputs2.logits
    predicted_token_ids2 = torch.argmax(logits2, dim=-1).to(device).flatten()

    new_text1 = tokenizer.decode(predicted_token_ids1, skip_special_tokens=True)
    new_text2 = tokenizer.decode(predicted_token_ids2, skip_special_tokens=True)

    print("original text1: ", text1)
    print("transfer text1: ", new_text1)
    print("original text2: ", text2)
    print("transfer text2: ", new_text2)
    print("--------------------------------------------------")

original text1:  You are acounsellor; if you can command these elements tosilence, and work the peace of the present, we willnot hand a rope more; use your authority: if youcannot, give thanks you have lived so long, and makeyourself ready in your cabin for the mischance ofthe hour, if it so hap
transfer text1:  –s,, if you can command these elements tosilence, and work the peace of the present, we willnot hand a rope more; use your authority: if youcannot, give thanks you have lived so long, and makeyourself ready in your cabin for the mischance ofthe hour, if it so hap
original text2:  Actor's wife, Emma Heming-Willis, cast doubt on the report Monday on Twitter
transfer text2:  ,, Emma Heming-Willis, cast doubt on the report Monday on Twitter
--------------------------------------------------
original text1:  Morgan Schneiderlin put in the hard yards for Southampton at Anfield
transfer text1:  put putin put in the hard yards for Southampton Anfield
original text2:  KING EDWARD IV:Lea