# A pre-trained BERT model use case

- The pre-trained model comes from Hugging Face.
- The dataset is the Kaggle 'Sentiment140 dataset with 1.6 million tweets'.

## Exploring the dataset

- Read in the file.
- Take 1000 tweets: the first 500 negative samples and the first 500 positive samples.
- Save them to a new file, then read through those 1000 samples to get a feel for the characteristics of the data.

In [None]:
import pandas as pd
import numpy as np 
import re 

# Both files live in the same local dataset folder: the raw Sentiment140
# training file and the balanced 1000-tweet sample derived from it.
_dataset_dir = "D:/CodeBase-User/VScode-workspace/BERT-Notebooks/Dataset"
ori_file_path = f"{_dataset_dir}/training.1600000.processed.noemoticon.csv"
slice_file_path = f"{_dataset_dir}/noemoticon-slice-1000.csv"

In [None]:
# The raw file has no header row, so name the six columns while reading it.
df_ori = pd.read_csv(
    ori_file_path,
    encoding='ISO-8859-1',
    header=None,
    names=['sentiment', 'id', 'date', 'flag', 'user', 'tweet'],
)
# Collapse the 0/4 polarity labels to 0/1.
df_ori.sentiment = df_ori.sentiment.map({4: 1, 0: 0})
# Balanced sample: the first 500 rows of each class.
negative_rows = df_ori.loc[df_ori.sentiment == 0].head(500)
positive_rows = df_ori.loc[df_ori.sentiment == 1].head(500)
df = pd.concat([negative_rows, positive_rows])
df.to_csv(slice_file_path)

- Here is an interesting problem with the dataset:
- some tweets appear with both a positive and a negative label.

In [None]:
# The same tweet text appearing with different labels (both 0 and 1).
df = pd.read_csv(slice_file_path)
# For every row, the number of distinct labels attached to its tweet text.
# A value above 1 means the annotations disagree. Exact duplicates that
# share one label are not conflicts and are not flagged here.
labels_per_tweet = df.groupby('tweet').sentiment.transform('nunique')
records_with_same_tweet = df[labels_per_tweet > 1]
records_with_same_tweet[['sentiment', 'user', 'tweet']]

In [None]:
# Remove every row whose tweet text was flagged in records_with_same_tweet,
# then plot how many tweets of each class remain.
flagged_tweets = records_with_same_tweet.tweet
df = df.loc[~df.tweet.isin(flagged_tweets)]
df.sentiment.value_counts().plot(kind='bar')

- Check the 'special tokens' that are common in tweets: topics (`#`), mentions (`@`), URLs (`http`/`https`), and HTML character entities left over from encoding (`&...;`, e.g. `&amp;`).

In [None]:
# Regexes for tweet-specific "special tokens": mentions, hashtags (topics),
# HTML character entities and URLs. They are searched for anywhere in the
# tweet with re.findall; a `^`-anchored re.match would only ever look at the
# first characters of each tweet.
pattern_mention = r'@[0-9a-zA-Z_]+'
pattern_topic = r'#[0-9a-zA-Z]+'
pattern_sign = r'&[a-z]+;'
pattern_url = r'http\S+'
token_flags = re.IGNORECASE | re.S
mention_count = 0
for tweet in df.tweet:
    # Guard against NaN / non-string cells, which re functions reject.
    tweet = str(tweet)
    # One per occurrence: count the mentions, not the characters in them.
    mention_count += len(re.findall(pattern_mention, tweet, flags=token_flags))
    # Print every hashtag, entity and URL so they can be inspected by eye.
    for pattern in (pattern_topic, pattern_sign, pattern_url):
        for match in re.findall(pattern, tweet, flags=token_flags):
            print(match)
print("The number of mentions: ", mention_count)

- Before cleaning the text, I check how frequent the words in the tweets are,
- to see whether this changes after the data is cleaned.

In [None]:
import nltk
import string
from nltk.corpus import stopwords

# NLTK's English stop-word list as a set, for fast membership tests.
stop_words = {word for word in stopwords.words('english')}

In [None]:
# Lower-case and whitespace-split each tweet once, then derive three
# per-tweet counts: all words, stop words, and everything else.
lowered_words = df.tweet.map(lambda tweet: str(tweet).lower().split())
df_words = lowered_words.map(len)
df_stopwords = lowered_words.map(lambda words: sum(1 for word in words if word in stop_words))
df_non_stopwords = lowered_words.map(lambda words: sum(1 for word in words if word not in stop_words))
print("Number of words: ", df_words.sum())
print("Number of stopwords: ", df_stopwords.sum())
print("Number of non stopwords: ", df_non_stopwords.sum())

In [None]:
from wordcloud import WordCloud
import matplotlib.pyplot as plt

def word_count(df, max_words=200):
    """Render a word cloud built from the ``tweet`` column of *df*.

    Words in the ``stop_words`` set are excluded from the cloud.

    Args:
        df: DataFrame with a ``tweet`` column.
        max_words: maximum number of words drawn in the cloud.
    """
    word_cloud = WordCloud(max_words=max_words, background_color='white', stopwords=stop_words,
                           colormap='rainbow', height=1000, width=700)
    # Join the tweets into a single text. Calling str() on the numpy array
    # would instead give its repr: brackets, quotes, escape sequences and,
    # past numpy's default print threshold of 1000 elements, a "..." summary
    # that silently drops most of the tweets.
    text = ' '.join(str(tweet) for tweet in df.tweet.values).lower()
    word_cloud.generate(text)
    plt.figure(figsize=(10, 10))
    plt.imshow(word_cloud)
    plt.show()

word_count(df)

## Data cleaning

- Tweets contain a great many unusual abbreviations.
- People usually write very informally in tweets.
- Typos are, of course, a big problem too.

In [1]:
# Lower-case all tweets first: the abbreviation lookup and the cleaning
# rules below are written against lower-case text.
df['tweet'] = df['tweet'].str.lower()
# Abbreviation / slang lookup table: during cleaning, a whitespace-separated
# word that exactly matches a key is replaced by its expansion. Keys are
# lower-case because the tweets are lower-cased before cleaning.
abbreviations = {
    "$": " dollar ",
    "€": " euro ",
    "4ao": "for adults only",
    "a.m": "before midday",
    "a3": "anytime anywhere anyplace",
    "aamof": "as a matter of fact",
    "acct": "account",
    "adih": "another day in hell",
    "afaic": "as far as i am concerned",
    "afaict": "as far as i can tell",
    "afaik": "as far as i know",
    "afair": "as far as i remember",
    "afk": "away from keyboard",
    "app": "application",
    "approx": "approximately",
    "apps": "applications",
    "asap": "as soon as possible",
    "asl": "age, sex, location",
    "atk": "at the keyboard",
    "ave.": "avenue",
    "aymm": "are you my mother",
    "ayor": "at your own risk",
    "b&b": "bed and breakfast",
    "b+b": "bed and breakfast",
    "b.c": "before christ",
    "b2b": "business to business",
    "b2c": "business to consumer",
    "b4": "before",
    "b4n": "bye for now",
    "b@u": "back at you",
    "bae": "before anyone else",
    "bak": "back at keyboard",
    "bbbg": "bye bye be good",
    "bbc": "british broadcasting corporation",
    "bbias": "be back in a second",
    "bbl": "be back later",
    "bbs": "be back soon",
    "be4": "before",
    "bfn": "bye for now",
    "blvd": "boulevard",
    "bout": "about",
    "brb": "be right back",
    "bros": "brothers",
    "brt": "be right there",
    "bsaaw": "big smile and a wink",
    "btw": "by the way",
    "bwl": "bursting with laughter",
    "c/o": "care of",
    "cet": "central european time",
    "cf": "compare",
    "cia": "central intelligence agency",
    "csl": "can not stop laughing",
    "cu": "see you",
    "cul8r": "see you later",
    "cv": "curriculum vitae",
    "cwot": "complete waste of time",
    "cya": "see you",
    "cyt": "see you tomorrow",
    "dae": "does anyone else",
    "dbmib": "do not bother me i am busy",
    "diy": "do it yourself",
    "dm": "direct message",
    "dwh": "during work hours",
    "e123": "easy as one two three",
    "eet": "eastern european time",
    "eg": "example",
    "embm": "early morning business meeting",
    "encl": "enclosed",
    "encl.": "enclosed",
    "etc": "and so on",
    "faq": "frequently asked questions",
    "fawc": "for anyone who cares",
    "fb": "facebook",
    "fc": "fingers crossed",
    "fig": "figure",
    "fimh": "forever in my heart",
    "ft.": "feet",
    "ft": "featuring",
    "ftl": "for the loss",
    "ftw": "for the win",
    "fwiw": "for what it is worth",
    "fyi": "for your information",
    "g9": "genius",
    "gahoy": "get a hold of yourself",
    "gal": "get a life",
    "gcse": "general certificate of secondary education",
    "gfn": "gone for now",
    "gg": "good game",
    "gl": "good luck",
    "glhf": "good luck have fun",
    "gmt": "greenwich mean time",
    "gmta": "great minds think alike",
    "gn": "good night",
    "g.o.a.t": "greatest of all time",
    "goat": "greatest of all time",
    "goi": "get over it",
    "gps": "global positioning system",
    "gr8": "great",
    "gratz": "congratulations",
    "gyal": "girl",
    "h&c": "hot and cold",
    "hp": "horsepower",
    "hr": "hour",
    "hrh": "his royal highness",
    "ht": "height",
    "ibrb": "i will be right back",
    "ic": "i see",
    "icq": "i seek you",
    "icymi": "in case you missed it",
    "idc": "i do not care",
    "idgadf": "i do not give a damn fuck",
    "idgaf": "i do not give a fuck",
    "idk": "i do not know",
    "ie": "that is",
    "i.e": "that is",
    "ifyp": "i feel your pain",
    "ig": "instagram",
    "iirc": "if i remember correctly",
    "ilu": "i love you",
    "ily": "i love you",
    "imho": "in my humble opinion",
    "imo": "in my opinion",
    "imu": "i miss you",
    "iow": "in other words",
    "irl": "in real life",
    "j4f": "just for fun",
    "jic": "just in case",
    "jk": "just kidding",
    "jsyk": "just so you know",
    "l8r": "later",
    "lb": "pound",
    "lbs": "pounds",
    "ldr": "long distance relationship",
    "lmao": "laugh my ass off",
    "lmfao": "laugh my fucking ass off",
    "lol": "laughing out loud",
    "ltd": "limited",
    "ltns": "long time no see",
    "m8": "mate",
    "mf": "motherfucker",
    "mfs": "motherfuckers",
    "mfw": "my face when",
    "mofo": "motherfucker",
    "mph": "miles per hour",
    "mr": "mister",
    "mrw": "my reaction when",
    "ms": "miss",
    "mte": "my thoughts exactly",
    "nagi": "not a good idea",
    "nbc": "national broadcasting company",
    "nbd": "no big deal",
    "nfs": "not for sale",
    "ngl": "not going to lie",
    "nhs": "national health service",
    "nrn": "no reply necessary",
    "nsfl": "not safe for life",
    "nsfw": "not safe for work",
    "nth": "nice to have",
    "nvr": "never",
    "nyc": "new york city",
    "oc": "original content",
    "og": "original",
    "ohp": "overhead projector",
    "oic": "oh i see",
    "omdb": "over my dead body",
    "omg": "oh my god",
    "omw": "on my way",
    "p.a": "per annum",
    "p.m": "after midday",
    "pm": "prime minister",
    "poc": "people of color",
    "pov": "point of view",
    "pp": "pages",
    "ppl": "people",
    "prw": "parents are watching",
    "ps": "postscript",
    "pt": "point",
    "ptb": "please text back",
    "pto": "please turn over",
    "qpsa": "what happens",
    "ratchet": "rude",
    "rbtl": "read between the lines",
    "rlrt": "real life retweet",
    "rofl": "rolling on the floor laughing",
    "roflol": "rolling on the floor laughing out loud",
    "rotflmao": "rolling on the floor laughing my ass off",
    "rt": "retweet",
    "ruok": "are you ok",
    "sfw": "safe for work",
    "sk8": "skate",
    "smh": "shake my head",
    "sq": "square",
    "srsly": "seriously",
    "ssdd": "same stuff different day",
    "tbh": "to be honest",
    "tbs": "tablespoonful",
    "tbsp": "tablespoonful",
    "tfw": "that feeling when",
    "thks": "thank you",
    "tho": "though",
    "thx": "thank you",
    "tia": "thanks in advance",
    "til": "today i learned",
    "tl;dr": "too long i did not read",
    "tldr": "too long i did not read",
    "tmb": "tweet me back",
    "tntl": "trying not to laugh",
    "ttyl": "talk to you later",
    "u": "you",
    "u2": "you too",
    "u4e": "yours forever",
    "utc": "coordinated universal time",
    "w/": "with",
    "w/o": "without",
    "w8": "wait",
    "wassup": "what is up",
    "wb": "welcome back",
    "wtf": "what the fuck",
    "wtg": "way to go",
    "wtpa": "where the party at",
    "wuf": "where are you from",
    "wuzup": "what is up",
    "wywh": "wish you were here",
    "yd": "yard",
    "ygtr": "you got that right",
    "ynk": "you never know",
    "zzz": "sleeping bored and tired",
}

def _compile_rules(pairs):
    """Compile ordered (regex, replacement) pairs for repeated use."""
    return [(re.compile(pattern), replacement) for pattern, replacement in pairs]


# Rules applied before the abbreviation lookup, in this exact order.
# Order matters: mojibake is normalised before the contraction rules run,
# and the longer alternative comes first in "4gotten|4got".
_PRE_ABBREVIATION_RULES = _compile_rules([
    # Mojibake fragments from mis-decoded punctuation: drop them.
    (r"\x89Û_", ""),
    (r"\x89ÛÒ", ""),
    (r"\x89ÛÓ", ""),
    (r"\x89ÛÏ", ""),
    (r"\x89Û÷", ""),
    # "\x89Ûª" stands in for an apostrophe (e.g. "it\x89Ûªs" -> "it is").
    # Turn it back into "'" so the contraction rules below expand it.
    # Deleting it here would produce "its" before any contraction rule ran.
    (r"\x89Ûª", "'"),
    (r"\x89Û\x9d", ""),
    (r"å_", ""),
    (r"\x89Û¢", ""),
    (r"\x89Û¢åÊ", ""),
    (r"åÊ", ""),
    (r"åÈ", ""),
    (r"Ì©", "e"),
    (r"å¨", ""),
    (r"åÇ", ""),
    (r"åÀ", ""),
    # Contractions.
    (r"let's", "let us"),
    (r"he's", "he is"),
    (r"there's", "there is"),
    (r"we're", "we are"),
    (r"that's", "that is"),
    (r"won't", "will not"),
    (r"they're", "they are"),
    (r"can't", "cannot"),
    (r"wasn't", "was not"),
    (r"don't", "do not"),
    (r"donå«t", "do not"),
    (r"aren't", "are not"),
    (r"isn't", "is not"),
    (r"what's", "what is"),
    (r"haven't", "have not"),
    (r"hasn't", "has not"),
    (r"it's", "it is"),
    (r"you're", "you are"),
    (r"i'm", "i am"),
    (r"shouldn't", "should not"),
    (r"wouldn't", "would not"),
    (r"here's", "here is"),
    (r"where's", "where is"),
    (r"you've", "you have"),
    (r"youve", "you have"),
    (r"couldn't", "could not"),
    (r"we've", "we have"),
    (r"doesn't", "does not"),
    (r"who's", "who is"),
    (r"i've", "i have"),
    (r"y'all", "you all"),
    (r"would've", "would have"),
    (r"it'll", "it will"),
    (r"we'll", "we will"),
    (r"he'll", "he will"),
    (r"weren't", "were not"),
    (r"didn't", "did not"),
    (r"they'll", "they will"),
    (r"they'd", "they would"),
    (r"they've", "they have"),
    (r"i'd", "i would"),
    (r"should've", "should have"),
    (r"we'd", "we would"),
    (r"i'll", "i will"),
    (r"^ill$", "i will"),
    (r"you'll", "you will"),
    (r"ain't", "am not"),
    (r"you'd", "you would"),
    (r"could've", "could have"),
    # Informal spellings and number-for-letter slang.
    (r"mÌ¼sica", "music"),
    (r"some1", "someone"),
    (r"yrs", "years"),
    (r"hrs", "hours"),
    (r"2morow|2moro", "tomorrow"),
    (r"2day", "today"),
    # Longest alternative first, otherwise "4gotten" becomes "forgetten".
    (r"4gotten|4got", "forget"),
    (r"b-day|bday", "b-day"),
    (r"mother's", "mother"),
    (r"mom's", "mom"),
    (r"dad's", "dad"),
    # A tweet made only of h's and a's ("hahaha") becomes "haha".
    # No "|" inside the class: it would match a literal pipe.
    (r"^[ha]+$", "haha"),
    (r"lmao|lolz|rofl", "lol"),
    (r"thanx|thnx|thx", "thanks"),
    # Elongated words ("sooo", "heyyy") collapse to their normal form.
    (r"all[l]+", "all"),
    # The lookahead keeps real words such as "soon" intact ("soo" + "n").
    (r"soo+(?![a-z])", "so"),
    (r"why[y]+", "why"),
    (r"way[y]+", "way"),
    (r"will[l]+", "will"),
    (r"oo[o]+h", "ooh"),
    (r"hey[y]+", "hey"),
    (r"boo[o]+m", "boom"),
    (r"co[o]+ld", "cold"),
    (r"goo[o]+d", "good"),
])

# Rules applied after the abbreviation lookup, in this exact order.
_POST_ABBREVIATION_RULES = _compile_rules([
    # HTML character entity references.
    (r"&gt;", ">"),
    (r"&lt;", "<"),
    (r"&amp;", "&"),
    # Typos, slang and informal abbreviations.
    (r"w/e", "whatever"),
    (r"usagov", "usa government"),
    (r"<3", "love"),
    (r"trfc", "traffic"),
    # URLs.
    (r"http\S+", ""),
    # Every @mention, not only one at the very start of the tweet.
    (r"@[0-9a-zA-Z_]+", ""),
])

# Deletes every ASCII punctuation character in a single pass.
_PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)


def get_tokens(tweet, abbreviation_map=None, stopword_set=None):
    """Clean one tweet and return its remaining words joined by single spaces.

    Pipeline:
    1. Normalise mojibake.
    2. Expand contractions and slang.
    3. Collapse elongated words.
    4. Expand whole-word abbreviations.
    5. Decode HTML entities.
    6. Drop URLs and @mentions.
    7. Strip ASCII punctuation.
    8. Remove stop words.

    The rules are written in lower case, so *tweet* is expected to be
    lower-cased already.

    Args:
        tweet: the tweet text.
        abbreviation_map: word -> expansion table; defaults to the
            module-level ``abbreviations``.
        stopword_set: words to drop from the result; defaults to the
            module-level ``stop_words``.

    Returns:
        The cleaned tweet as a single space-separated string.
    """
    if abbreviation_map is None:
        abbreviation_map = abbreviations
    if stopword_set is None:
        stopword_set = stop_words
    for pattern, replacement in _PRE_ABBREVIATION_RULES:
        tweet = pattern.sub(replacement, tweet)
    # Whole-word abbreviation expansion (this also collapses runs of whitespace).
    tweet = ' '.join(abbreviation_map.get(word, word) for word in tweet.split())
    for pattern, replacement in _POST_ABBREVIATION_RULES:
        tweet = pattern.sub(replacement, tweet)
    # Punctuation is stripped entirely, so no "..." or ".." survives this step.
    tweet = tweet.translate(_PUNCTUATION_TABLE)
    # NOTE(review): if the stop-word set contains negations (e.g. "not"),
    # dropping them can flip a tweet's sentiment; consider keeping them.
    return ' '.join(word for word in tweet.split() if word not in stopword_set)

# Clean every tweet; the cleaned text goes into a new "tokens" column.
df['tokens'] = df['tweet'].map(get_tokens)
print("done")

- Write the results to the target file.
- Then check the word frequencies after cleaning.

In [None]:
# Keep only the label and the cleaned text, renaming the latter to "tweet"
# so that later cells can read the file with the same column names.
# rename() returns a new frame instead of mutating a column slice in place,
# which avoids pandas' SettingWithCopyWarning.
df = df[['sentiment', 'tokens']].rename(columns={'tokens': 'tweet'})
# index=False: the row index carries no information, and writing it would
# add an extra "Unnamed: 0" column when the file is read back.
df.to_csv("D:/CodeBase-User/VScode-workspace/BERT-Notebooks/Dataset/noemoticon-cleaned-1000.csv", index=False)
word_count(df)

## Model building (PyTorch & transformers)

- Now use the pre-trained transformer models from Hugging Face with PyTorch.
- I use the 'bert-base-uncased' model.
- The dataset has only two classes: positive and negative.

In [None]:
import torch
from torch.nn import functional as F
from transformers import BertTokenizer, BertModel, BertConfig, AdamW
from torch.utils.data import TensorDataset, DataLoader, RandomSampler

# Hyper-parameters for the classifier in this section.
max_seq_length = 128   # tokens per tweet, [CLS] and [SEP] included
num_class = 1          # one sigmoid output: positive vs. negative
batch_size = 32
epoch = 10             # number of passes over the training data

bert_model_path = "E:/PreTrainedModels/Huggingface/bert-base-uncased"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
df = pd.read_csv("D:/CodeBase-User/VScode-workspace/BERT-Notebooks/Dataset/noemoticon-cleaned-1000.csv")
df.head()

- Implement a method to tokenize the tweets.
- Instead of calling the tokenizer's built-in batch encoding, build each input by hand to show how its parts are structured: `[CLS]`, the tokens, `[SEP]`, the padding, and the attention mask.

In [None]:
def covert_tweet_format(df, tokenizer, max_seq_length):
    """Tokenize ``df.tweet`` by hand into fixed-length BERT inputs.

    Each tweet becomes ``[CLS] tok_1 ... tok_n [SEP]`` followed by padding
    up to *max_seq_length*. Tweets that are too long are truncated so that
    the special tokens still fit.

    Args:
        df: DataFrame with a ``tweet`` text column and a ``sentiment`` label column.
        tokenizer: object exposing ``tokenize`` and ``convert_tokens_to_ids``
            (e.g. ``BertTokenizer``). If present, its ``cls_token``,
            ``sep_token`` and ``pad_token_id`` are used.
        max_seq_length: total sequence length, special tokens included.

    Returns:
        ``(input_ids, attention_mask, labels)``: two long tensors of shape
        ``(N, max_seq_length)`` and a float tensor of shape ``(N, 1)``.
    """
    # Use the tokenizer's own special tokens and pad id when it defines them;
    # otherwise fall back to "[CLS]", "[SEP]" and 0.
    cls_token = getattr(tokenizer, 'cls_token', None) or '[CLS]'
    sep_token = getattr(tokenizer, 'sep_token', None) or '[SEP]'
    pad_id = getattr(tokenizer, 'pad_token_id', None)
    if pad_id is None:
        pad_id = 0
    # Room left for real tokens once [CLS] and [SEP] are accounted for.
    # Clamp at 0 so a tiny max_seq_length cannot turn the slice negative.
    content_len = max(max_seq_length - 2, 0)

    all_tokens = []
    all_masks = []
    for tweet in df.tweet:
        pieces = tokenizer.tokenize(str(tweet))[:content_len]
        input_sequence = [cls_token] + pieces + [sep_token]
        pad_len = max_seq_length - len(input_sequence)
        token_ids = list(tokenizer.convert_tokens_to_ids(input_sequence)) + [pad_id] * pad_len
        # 1 marks real positions ([CLS]/[SEP] included), 0 marks padding.
        pad_masks = [1] * len(input_sequence) + [0] * pad_len
        all_tokens.append(token_ids)
        all_masks.append(pad_masks)
    all_labels = [[label] for label in df.sentiment]
    return (torch.tensor(all_tokens, dtype=torch.long),
            torch.tensor(all_masks, dtype=torch.long),
            torch.tensor(all_labels, dtype=torch.float))

- Now load the model from the Hugging Face repo.
- Once downloaded, it can be loaded from a local path.
- Use a random sampler to draw the training batches.

In [None]:
# Load the vocabulary, the configuration and the pre-trained encoder
# weights from the local bert-base-uncased directory.
tokenizer = BertTokenizer.from_pretrained(bert_model_path)
bert_config = BertConfig.from_pretrained(bert_model_path)
bert_model = BertModel.from_pretrained(bert_model_path)

# Hand-built encodings, attention masks and labels for every tweet.
all_input_ids, all_input_mask, all_labels = covert_tweet_format(
    df, tokenizer, max_seq_length
)

# Batches are drawn in random order from the whole dataset.
train_data = TensorDataset(all_input_ids, all_input_mask, all_labels)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(
    train_data,
    sampler=train_sampler,
    batch_size=batch_size,
)

- Define the custom layers placed on top of BERT.
- Use the pre-trained BERT model as the encoder layer and freeze its weights.
- Biases and LayerNorm weights are excluded from weight decay during training.

In [None]:
class TweetClassifier(torch.nn.Module):
    """Binary tweet classifier: a BERT encoder followed by a small MLP head.

    Args:
        bert_model: encoder called as ``bert_model(input_ids, attention_mask=...)``.
            Element 1 of its output (``pooler_output`` for a ``BertModel``
            with its default pooling layer) feeds the head.
        bert_config: config object providing ``hidden_size``.
        num_class: number of sigmoid outputs (1 for positive vs. negative).
    """

    def __init__(self, bert_model, bert_config, num_class):
        super(TweetClassifier, self).__init__()
        self.bert_layer = bert_model
        hidden = bert_config.hidden_size
        # ReLUs between the Linear layers: without a non-linearity, the three
        # stacked Linear layers collapse into a single affine map.
        self.output = torch.nn.Sequential(
            torch.nn.Dropout(0.3),
            torch.nn.Linear(hidden, hidden // 2),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden // 2, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, num_class),
            torch.nn.Sigmoid()
        )

    def forward(self, input_ids, attn_mask):
        """Return per-example probabilities of shape ``(batch, num_class)``."""
        bert_out = self.bert_layer(input_ids, attention_mask=attn_mask)
        # Index 1 is the pooled output for both tuple and ModelOutput returns.
        # Index -1 would instead pick up hidden_states or attentions whenever
        # the config enables them.
        pooled = bert_out[1]
        return self.output(pooled)

In [None]:
tweet_classifier = TweetClassifier(bert_model=bert_model, bert_config=bert_config, num_class=num_class)
tweet_classifier = tweet_classifier.to(device)
# Freeze the pre-trained encoder: only the classification head is trained.
freeze_param = ['bert_layer']
for n, p in tweet_classifier.named_parameters():
    if any(nd in n for nd in freeze_param):
        p.requires_grad = False
# Weight decay is applied to every trainable parameter except biases and
# LayerNorm weights. Frozen parameters are left out of the optimizer entirely.
no_decay = ['bias', 'LayerNorm.weight']
trainable = [(n, p) for n, p in tweet_classifier.named_parameters() if p.requires_grad]
optimizer_grouped_parameters = [
    {'params': [p for n, p in trainable if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in trainable if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
# torch.optim.AdamW replaces the deprecated transformers.AdamW.
# eps=1e-6 keeps transformers.AdamW's default epsilon.
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=1e-5, eps=1e-6)
tweet_classifier.train()

- Finally, define the training process.

In [None]:
for _ in range(epoch):
    epoch_loss = 0.0
    for batch, (token_ids, attn_mask, label) in enumerate(train_dataloader):
        # Move the batch onto the same device as the model.
        token_ids = token_ids.to(device)
        attn_mask = attn_mask.to(device)
        label = label.to(device)
        # The model outputs sigmoid probabilities, so plain BCE applies.
        outputs = tweet_classifier(token_ids, attn_mask)
        loss = F.binary_cross_entropy(outputs, label)
        # Backprop and update the trainable (head) parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() gives a detached Python float; read it once and reuse it.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        if batch % 30 == 0:
            print("Current batch loss: ", batch_loss)
    print("Now epoch total loss: ", epoch_loss)

## Model building (inherit from BertPreTrainedModel)

- Another example of using a pre-trained Hugging Face model with PyTorch and transformers.
- This time the classifier inherits from `BertPreTrainedModel`, and the BERT encoder is loaded inside it with `BertModel.from_pretrained()`.
- In addition, I use the Hugging Face tokenizer directly and add a k-fold training strategy.

In [None]:
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
from sklearn.model_selection import StratifiedKFold
from torch.nn import functional as F
from transformers import BertTokenizer, BertModel, BertConfig, AdamW, BertPreTrainedModel


max_seq_length = 128
num_class = 1
batch_size = 32
k = 3        # number of cross-validation folds
seed = 731   # seeds the shuffle below (and the fold split built from it)

bert_path = "E:/PreTrainedModels/Huggingface/bert-base-uncased"
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
df_train = pd.read_csv("D:/CodeBase-User/VScode-workspace/BERT-Notebooks/Dataset/noemoticon-cleaned-1000.csv")
# Shuffle reproducibly, then renumber the rows 0..N-1. StratifiedKFold yields
# positional indices, so after this the index labels and positions coincide,
# and label-based lookups select the intended rows. reset_index returns a
# new frame, so its result has to be assigned back.
df_train = df_train.sample(frac=1, random_state=seed).reset_index(drop=True)
df_train.head()

In [None]:
class TweetClassifier(BertPreTrainedModel):
    """BERT encoder plus dropout and a small MLP head, built on BertPreTrainedModel.

    Args:
        bert_model: path or hub name of the pre-trained BERT weights; they are
            loaded with ``BertModel.from_pretrained``.
        bert_config: ``BertConfig`` used for both the wrapper and the encoder.
        num_class: number of sigmoid outputs (1 for positive vs. negative).
    """

    def __init__(self, bert_model, bert_config, num_class):
        super(TweetClassifier, self).__init__(bert_config)
        self.bert_layer = BertModel.from_pretrained(bert_model, config=bert_config)
        self.dropout = torch.nn.Dropout(0.3)
        hidden = bert_config.hidden_size
        # ReLUs between the Linear layers: without a non-linearity, the three
        # stacked Linear layers collapse into a single affine map.
        self.output = torch.nn.Sequential(
            torch.nn.Linear(hidden, hidden // 2),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden // 2, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, num_class),
            torch.nn.Sigmoid()
        )

    def forward(self, input_ids, attn_mask, token_type_ids=None):
        """Return per-example probabilities of shape ``(batch, num_class)``."""
        bert_out = self.bert_layer(input_ids, attention_mask=attn_mask, token_type_ids=token_type_ids)
        # Index 1 is the pooled output for both tuple and ModelOutput returns.
        # Index -1 would instead pick up hidden_states or attentions whenever
        # the config enables them.
        pooled = bert_out[1]
        return self.output(self.dropout(pooled))

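- Use the tokenizer provided by Hugging Face directly.
- The helper below converts every tweet to a string, batch-encodes the tweets (padding and truncating them to `max_seq_length`), and wraps the labels in a float tensor.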

In [None]:
def covert_tweet_format(tweets, sentiments, tokenizer, max_seq_length):
    """Batch-encode *tweets* with *tokenizer* and pair them with float labels.

    Every tweet is converted with ``str()`` first. The tokenizer is asked to
    pad and truncate to *max_seq_length* and to return PyTorch tensors.

    Returns:
        ``(encoded_inputs, labels)``: ``encoded_inputs`` is whatever the
        tokenizer returns, and ``labels`` is a float tensor of shape ``(N, 1)``.
    """
    texts = list(map(str, tweets))
    encoded_inputs = tokenizer(
        texts,
        padding='max_length',
        max_length=max_seq_length,
        truncation=True,
        return_tensors="pt",
    )
    label_rows = [[sentiment] for sentiment in sentiments]
    return encoded_inputs, torch.tensor(label_rows, dtype=torch.float)

- Create the tokenizer, the k-fold splitter and the optimizer.
- Load the pre-trained model.

In [None]:
tokenizer = BertTokenizer.from_pretrained(bert_path)
# Stratified folds keep the positive/negative ratio equal across splits.
k_fold = StratifiedKFold(n_splits=k, random_state=seed, shuffle=True)
bert_config = BertConfig.from_pretrained(bert_path)
tweet_classifier = TweetClassifier(bert_path, bert_config, num_class)
tweet_classifier = tweet_classifier.to(device)

# Biases and LayerNorm weights are excluded from weight decay.
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in tweet_classifier.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in tweet_classifier.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
# torch.optim.AdamW replaces the deprecated transformers.AdamW.
# eps=1e-6 keeps transformers.AdamW's default epsilon.
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=1e-5, eps=1e-6)

tweet_classifier.train()

- Now train the BERT model with k-fold cross-validation.

In [None]:
# K-fold cross-validation: every fold trains a freshly initialised classifier
# on k-1 parts of the data and reports its accuracy on the held-out part.
# Reusing one model across folds would let later folds "validate" on rows
# the model has already been trained on.
epoch_num = 10  # passes over each fold's training split
no_decay = ['bias', 'LayerNorm.weight']
fold_accuracies = []
for fold, (trn_idx, val_idx) in enumerate(k_fold.split(df_train.tweet, df_train.sentiment)):
    tweet_classifier = TweetClassifier(bert_path, bert_config, num_class).to(device)
    optimizer = torch.optim.AdamW([
        {'params': [p for n, p in tweet_classifier.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in tweet_classifier.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    ], lr=1e-5, eps=1e-6)

    # StratifiedKFold yields positions, so select the rows with .iloc.
    encoded_inputs, labels = covert_tweet_format(df_train.tweet.iloc[trn_idx], df_train.sentiment.iloc[trn_idx], tokenizer, max_seq_length)
    train_data = TensorDataset(encoded_inputs['input_ids'], encoded_inputs['attention_mask'], encoded_inputs['token_type_ids'], labels)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

    tweet_classifier.train()
    for epoch in range(epoch_num):
        epoch_loss = 0.0
        for batch, (token_ids, attn_mask, token_types, batch_labels) in enumerate(train_dataloader):
            token_ids = token_ids.to(device)
            attn_mask = attn_mask.to(device)
            token_types = token_types.to(device)
            batch_labels = batch_labels.to(device)
            outputs = tweet_classifier(token_ids, attn_mask, token_types)
            loss = F.binary_cross_entropy(outputs, batch_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_loss = loss.item()
            epoch_loss += batch_loss
            if batch % 30 == 0:
                print("Current batch loss: ", batch_loss)
        print("Now epoch: ", epoch + 1, " current loss: ", epoch_loss)

    # Evaluate on the held-out rows: dropout off, no gradient tracking.
    val_inputs, val_labels = covert_tweet_format(df_train.tweet.iloc[val_idx], df_train.sentiment.iloc[val_idx], tokenizer, max_seq_length)
    val_data = TensorDataset(val_inputs['input_ids'], val_inputs['attention_mask'], val_inputs['token_type_ids'], val_labels)
    val_dataloader = DataLoader(val_data, batch_size=batch_size)
    tweet_classifier.eval()
    correct = 0
    with torch.no_grad():
        for token_ids, attn_mask, token_types, batch_labels in val_dataloader:
            outputs = tweet_classifier(token_ids.to(device), attn_mask.to(device), token_types.to(device))
            # A probability of 0.5 or more counts as a positive (1) prediction.
            predictions = (outputs >= 0.5).float()
            correct += (predictions == batch_labels.to(device)).sum().item()
    fold_accuracy = correct / len(val_idx)
    fold_accuracies.append(fold_accuracy)
    print("Fold: ", fold + 1, " validation accuracy: ", fold_accuracy)
print("Mean validation accuracy: ", sum(fold_accuracies) / len(fold_accuracies))