In [1]:
import pandas as pd
import os
import torch
import numpy as np
from transformers import XLNetModel, XLNetTokenizer, XLNetForSequenceClassification, AdamW, XLNetConfig
#from tokenizers import BertWordPieceTokenizer
from torch import nn
np.random.seed(321)

In [2]:
#BertModel.from_pretrained("./Pretrained Models/bert_weights/")

In [3]:
ROOT_PATH = "~/Research/CellularLint"
DATA_PATH = "./Data/SNLI/"
PRETRAINED_PATH = "./Pretrained Models/xlnet_weights/"
SAVE_MODEL_AT = "./saved_models/xlnet"
PRETRAINED_TOKENIZER = "saved_models/xlnet"
MODEL_PATH = None

In [4]:
df_train = pd.read_csv(os.path.join(DATA_PATH,"snli_1.0_train.csv"))
df_dev = pd.read_csv(os.path.join(DATA_PATH,"snli_1.0_dev.csv"))
df_test = pd.read_csv(os.path.join(DATA_PATH,"snli_1.0_test.csv"))

In [5]:
df_train = df_train.iloc[:500]
df_dev = df_train.copy().iloc[:200]

In [6]:
#tokenizer = BertTokenizerFast.from_pretrained(os.path.join(ROOT_PATH,PRETRAINED_TOKENIZER))
#tokenizer = BertTokenizerFast.from_pretrained(os.path.join(ROOT_PATH,PRETRAINED_PATH))
tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")

In [7]:
labels = {'contradiction':1,
          'entailment':0,
          'neutral':2,
          }
NUM_LABELS = len(labels)

In [8]:
def xlnet_encoding(sequence):
    return tokenizer.encode(sequence, add_special_tokens = False)
def str_to_int_list(data):
    return list(map(int, data))

In [9]:
df_train['token_type'] = df_train['token_type'].str.split()
df_train['token_type'] = df_train['token_type'].apply(str_to_int_list)

df_dev['token_type'] = df_dev['token_type'].str.split()
df_dev['token_type'] = df_dev['token_type'].apply(str_to_int_list)

df_test['token_type'] = df_test['token_type'].str.split()
df_test['token_type'] = df_test['token_type'].apply(str_to_int_list)


In [10]:
df_train['attention_mask'] = df_train['attention_mask'].str.split()
df_train['attention_mask'] = df_train['attention_mask'].apply(str_to_int_list)

df_dev['attention_mask'] = df_dev['attention_mask'].str.split()
df_dev['attention_mask'] = df_dev['attention_mask'].apply(str_to_int_list)

df_test['attention_mask'] = df_test['attention_mask'].str.split()
df_test['attention_mask'] = df_test['attention_mask'].apply(str_to_int_list)


In [11]:
df_train['input_ids'] = df_train['sequence'].apply(xlnet_encoding)
df_dev['input_ids'] = df_dev['sequence'].apply(xlnet_encoding)
df_test['input_ids'] = df_test['sequence'].apply(xlnet_encoding)

In [12]:
max_length = 512
class Dataset(torch.utils.data.Dataset):
    def __init__(self, df):
        self.texts = []
        self.labels = [labels[label] for label in df['gold_label']]
        for _, row in df.iterrows():
            token_type_ids = row['token_type']
            token_type_ids += [0] * (max_length - len(token_type_ids))
            attention_mask = row['attention_mask']
            attention_mask += [0] * (max_length - len(attention_mask))
            input_ids = tokenizer.encode(
                row['sequence'],
                add_special_tokens=False,
                padding='max_length',
                max_length=max_length,
                truncation=True,
                return_tensors="pt"
            )
            datadict = {
                'input_ids': input_ids.squeeze(0),
                'token_type_ids': torch.tensor(token_type_ids),
                'attention_mask': torch.tensor(attention_mask)
            }
            self.texts.append(datadict)

    def classes(self):
        return self.labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        batch_text = self.texts[idx]
        batch_y = self.labels[idx]
        return batch_text, batch_y

In [13]:
class XLNetClassifier(nn.Module):

    def __init__(self, load_path = None, dropout=0.5):

        super(XLNetClassifier, self).__init__()

        #self.bert = BertModel.from_pretrained(os.path.join(ROOT_PATH, load_path))
        self.xlnet = XLNetModel.from_pretrained(load_path)
        
        #self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(768, NUM_LABELS)
        
        self.softmax = nn.Softmax(dim = 1)

    def forward(self, input_id, mask, token_type_id):
        output = self.xlnet(input_ids= input_id, attention_mask = mask, token_type_ids = token_type_id,return_dict = False)
        #_, pooled_output = self.xlnet(input_ids= input_id, attention_mask = mask, return_dict = False)
        
        pooled_output = output[0][:,-1,:] #Representation from last token
        dropout_output = self.dropout(pooled_output)
        
        linear_output = self.linear(dropout_output)
        
        final_layer = self.softmax(linear_output)
        
        return final_layer
    
    def save(self, save_dir, tokenizer, model_name = "model_xlnet.pt"):
        
        os.makedirs(save_dir, exist_ok=True)
        # Save model weights
        #model_path = os.path.join(save_dir, model_name)
        #torch.save(self.state_dict(), model_path)
        self.xlnet.save_pretrained(save_dir)
        
        # Save tokenizer
        tokenizer.save_pretrained(save_dir)

        # Save other related information
        #config_path = os.path.join(save_dir, "config.json")
        #self.xlnet.config.to_json_file(config_path)

    def load(self, load_dir, is_eval = True, model_name = "model_xlnet.pt"):
        # Load tokenizer
        tokenizer = XLNetTokenizer.from_pretrained(load_dir)

        # Load other related information
        config_path = os.path.join(load_dir, "config.json")
        config = XLNetConfig.from_json_file(config_path)
        self.xlnet = XLNetModel(config)
        if is_eval:
            self.xlnet.eval()  # Set to evaluation mode

        # Load model weights
        model_path = os.path.join(load_dir, model_name)
        self.load_state_dict(torch.load(model_path))

        # Update the tokenizer
        self.xlnet.resize_token_embeddings(len(tokenizer))

In [14]:
from tqdm import tqdm
learning_rates = [1e-5] #[5e-6, 1e-5, 2e-5, 3e-5, 5e-5]
batch_sizes = [8] #[16, 24, 32, 40]

def train(model, train_data, val_data, learning_rate, epochs, batch_size):
    #out_file.write(f"-------------Starting with LR = {learning_rate} and BS = {batch_size}-----------------\n")
    best_acc_val = -99999

    train, val = Dataset(train_data), Dataset(val_data)
    
    train_dataloader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val, batch_size=batch_size)
    
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    
    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr= learning_rate)
    
    if use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()
        
    for epoch_num in range(epochs):

        total_acc_train = 0
        total_loss_train = 0
        
        for train_input, train_label in tqdm(train_dataloader):
                #print(train_input)
                train_label = train_label.to(device)
                
                input_id = train_input['input_ids'].squeeze(1).to(device)
                mask = train_input['attention_mask'].squeeze(1).to(device)
                token_type_id = train_input['token_type_ids'].squeeze(1).to(device)

                output = model(input_id, mask, token_type_id)
                #logits = output.logits #For BertForSequenceClassification
                logits = output
                #print(f'training logits: {logits}')
                #print("training outputs")
                #print(output)
                optimizer.zero_grad()
                
                batch_loss = criterion(logits, train_label.long())
                total_loss_train += batch_loss.item()
                #print(f'prediction: {torch.argmax(logits, dim=1)}')
                #print(f"train labels: {train_label}")
                acc = (torch.argmax(logits, dim=1) == train_label).sum().item()
                total_acc_train += acc

                
                batch_loss.backward()
                optimizer.step()
            
        total_acc_val = 0
        total_loss_val = 0

        with torch.no_grad():

            for val_input, val_label in val_dataloader:

                val_label = val_label.to(device)
                input_id = val_input['input_ids'].squeeze(1).to(device)
                mask = val_input['attention_mask'].squeeze(1).to(device)
                token_type_id = val_input['token_type_ids'].squeeze(1).to(device)

                output = model(input_id, mask, token_type_id)
                #logits = output.logits #For BertForSequenceClassification
                logits = output
                #print(val_label)
                #print(f'validating logits: {logits}')
                #print(f'prediction: {torch.argmax(logits, dim=1)}')

                batch_loss = criterion(logits, val_label.long())
                total_loss_val += batch_loss.item()
                #print(output)
                acc = (torch.argmax(logits, dim=1) == val_label).sum().item()
                total_acc_val += acc

        print(
            f'Epochs: {epoch_num + 1} | Train Loss: {total_loss_train / len(train_data): .3f} \
            | Train Accuracy: {total_acc_train / len(train_data): .3f} \
            | Val Loss: {total_loss_val / len(val_data): .3f} \
            | Val Accuracy: {total_acc_val / len(val_data): .3f}')

        if(total_acc_val / len(val_data) > best_acc_val):
            best_acc_val = total_acc_val / len(val_data)
            #model.bert.save_pretrained(os.path.join(ROOT_PATH,"saved_models/finetune_p1/bert/"))
            #tokenizer.save_pretrained(os.path.join(ROOT_PATH,"saved_models/finetune_p1/bert/"))

            model.save(save_dir = SAVE_MODEL_AT, tokenizer = tokenizer)
            print("Found a better model")

EPOCHS = 8
for LR in learning_rates:
    for bs in batch_sizes:
        #model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels = NUM_LABELS)
        model = XLNetClassifier(load_path = PRETRAINED_PATH)
        train(model, df_train, df_dev, LR, EPOCHS, bs)

Some weights of the model checkpoint at ./Pretrained Models/xlnet_weights/ were not used when initializing XLNetModel: ['lm_loss.bias', 'lm_loss.weight']
- This IS expected if you are initializing XLNetModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|███████████████████████████████████████████████████████████████████████████████████| 16/16 [00:24<00:00,  1.56s/it]


Epochs: 1 | Train Loss:  0.038             | Train Accuracy:  0.348             | Val Loss:  0.043             | Val Accuracy:  0.275
Found a better model


100%|███████████████████████████████████████████████████████████████████████████████████| 16/16 [00:24<00:00,  1.53s/it]


Epochs: 2 | Train Loss:  0.038             | Train Accuracy:  0.314             | Val Loss:  0.043             | Val Accuracy:  0.285
Found a better model


100%|███████████████████████████████████████████████████████████████████████████████████| 16/16 [00:24<00:00,  1.54s/it]


Epochs: 3 | Train Loss:  0.037             | Train Accuracy:  0.360             | Val Loss:  0.041             | Val Accuracy:  0.365
Found a better model


100%|███████████████████████████████████████████████████████████████████████████████████| 16/16 [00:24<00:00,  1.54s/it]


Epochs: 4 | Train Loss:  0.038             | Train Accuracy:  0.332             | Val Loss:  0.042             | Val Accuracy:  0.305


100%|███████████████████████████████████████████████████████████████████████████████████| 16/16 [00:24<00:00,  1.53s/it]


Epochs: 5 | Train Loss:  0.038             | Train Accuracy:  0.338             | Val Loss:  0.041             | Val Accuracy:  0.330


100%|███████████████████████████████████████████████████████████████████████████████████| 16/16 [00:24<00:00,  1.53s/it]


Epochs: 6 | Train Loss:  0.038             | Train Accuracy:  0.316             | Val Loss:  0.042             | Val Accuracy:  0.340


100%|███████████████████████████████████████████████████████████████████████████████████| 16/16 [00:24<00:00,  1.53s/it]


Epochs: 7 | Train Loss:  0.038             | Train Accuracy:  0.330             | Val Loss:  0.041             | Val Accuracy:  0.345


100%|███████████████████████████████████████████████████████████████████████████████████| 16/16 [00:24<00:00,  1.53s/it]


Epochs: 8 | Train Loss:  0.038             | Train Accuracy:  0.338             | Val Loss:  0.040             | Val Accuracy:  0.370
Found a better model
