# Flair Classification using RNN/GRU

In [0]:
import torch
import torch.nn as nn
from torch import optim
from torch.nn import functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm

In [0]:
import pandas as pd

In [0]:
train_df = pd.read_csv('train.csv')
valid_df = pd.read_csv('valid.csv')

## Text preprocess

In [0]:
import pandas as pd
import numpy as np
import re
import emoji
import nltk
nltk.download('wordnet')
nltk.download('stopwords')
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [0]:
def textPreProcess(text, rem_stop=False):
  """
  Function to process text data and remove non neccessary information
  1. converting emojis to text
  2. lower casing
  3. removing punctuations, urls
  4. removing stop words
  5. lemmatization 
  Args:
    text (str): text data
    rem_stop (bool): whether to remove stopping words or not
  Return:
    text (str): processed text data
  """
  # converting emojis to text
  text = emoji.demojize(text)
  # removing empty space at start and end of text and lower casing
  text = text.strip().lower()
  # removing punctuation
  PUNCTUATIONS = '[!()\-[\]{};:"\,<>/?@#$%^&.*_~]'
  text = re.sub(PUNCTUATIONS, "", text)
  # removing url links
  urlPattern = re.compile(r'https?://\S+|www\.\S+')
  urlPattern.sub('', text)

  # updating stopping words list
  if rem_stop:
    stopWords = list(stopwords.words('english'))
  else:
    stopWords = []

  # lemmatizing, removing stop word, removing emojis
  lemmaWords=[]
  Lemma=WordNetLemmatizer()
  for word in text.split():
    # removing stop words
    if word not in stopWords:
      lemmaWords.append(Lemma.lemmatize(word.strip()))
  text = " ".join(lemmaWords)

  return text

category = {'Non-Political':0, 'Coronavirus':1, 'Politics':2,
            'Policy/Economy':3, 'Food':4, 'Science/Technology':5,
            'Business/Finance':6, 'Photography':7, 'Sports':8}

In [0]:
def getClasData(df):
  data = pd.DataFrame(columns=['Text', 'Flair'])
  data['Text'] = df['Headline'].map(str) + df['PostBody'].map(str) + df['Comments'].map(str)
  data['Flair'] = df['Flair'].map(str)

  data['Text'] = data['Text'].apply(lambda text: textPreProcess(text))
  # data['Flair'] = le.transform(data['Flair'])
  data['Flair'] = data['Flair'].apply(lambda flair: category[flair])
  return data

In [0]:
data_trn = getClasData(train_df)
data_vld = getClasData(valid_df)

In [0]:
data_trn.to_csv('clean_train_data.csv')
data_vld.to_csv('clean_valid_data.csv')

In [0]:
import transformers
from transformers import BertTokenizer, DistilBertTokenizer

In [0]:
tokenizer1 = BertTokenizer.from_pretrained('bert-base-uncased')
tokenizer2 = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

In [0]:
class RedditData(Dataset):
  def __init__(self, path, maxlen):
    # read csv data file
    self.df = pd.read_csv(path)
    # maximum length of the text to be considered
    self.maxlen = maxlen
    # tokenizer
    self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

  def __len__(self):
    # returns the lenght of dataFrame
    return len(self.df)

  def __getitem__(self, idx):
    """Returns the text tokens and labels based on the index(idx)"""
    # text data
    text = self.df['Text'].iloc[idx]
    # label(flair) of the text
    label = self.df['Flair'].iloc[idx]

    # get token ids and attention mask
    tokens_dict = self.tokenizer.encode(text =text,
                          add_special_tokens=True,
                          max_length = self.maxlen,
                          pad_to_max_length=True)    
    tokens, labels = torch.tensor(tokens_dict), torch.tensor(label)

    return tokens, labels

In [0]:
max_len = 200;
train_set = RedditData('/content/clean_train_data.csv', max_len)
test_set = RedditData('/content/clean_valid_data.csv', max_len)

bs = 8
trainloader = DataLoader(train_set, shuffle = True, batch_size=bs)
testloader = DataLoader(test_set, shuffle = True, batch_size=bs)

## Model Definition

In [0]:
class GRUNet(nn.Module):
  def __init__(self, vocab_size, bs, embed_dim, hidden_dim, n_layers, num_labels, max_len, drop_p,):
    super(GRUNet, self).__init__()
    self.bs = bs
    self.n_layers = n_layers
    self.hidden_dim = hidden_dim 
    
    self.encoder = nn.Embedding(vocab_size, embed_dim)
    self.gru = nn.GRU(input_size = embed_dim,
                      hidden_size = hidden_dim,
                      num_layers=n_layers,
                      dropout = drop_p,
                      batch_first = True,
                      # bidirectional = True,
                      )
    n_in = 1024
    self.decoder_L1 = nn.Linear(hidden_dim*max_len, num_labels)
    # self.decoder_L2 = nn.Linear(n_in, num_labels, bias = False)
    self.dropout = nn.Dropout(drop_p)
    self.sigmoid = nn.Sigmoid()

  def init_hidden(self):
    weight = next(self.parameters()).data
    return weight.new(self.n_layers, self.bs, self.hidden_dim).zero_()

  def forward(self, x):
    self.bs = x.shape[0]
    embeds = self.encoder(x)
    gru_out, hidden = self.gru(embeds, self.init_hidden())
    # gru_out.shape == bs * max_len * hidden_dim
    # hidden.shape == n_layers * bs * hidden_dim
    # hidden shape can be simply understood as gru_out[i] - gru_out[i-1] where i be the size of hidden states

    # reshaping the gru output for the linear layer
    gru_out = gru_out.reshape(self.bs, -1)
    # fully-connected layer(decoder)
    out = self.decoder_L1(gru_out)
    # out = self.decoder_L2(self.dropout(out))
    return self.sigmoid(out)

In [0]:
vocab_size = len(train_set.tokenizer.vocab)
batch_size = bs
embed_dim = 512
hidden_states = 256
num_layers = 2
num_labels = len(category)
dropout = 0.6
max_len = max_len

model = GRUNet(
    vocab_size = vocab_size,
    bs = batch_size,
    embed_dim = embed_dim,
    hidden_dim =  hidden_states,
    n_layers = num_layers,
    num_labels = num_labels,
    max_len = max_len,
    drop_p = dropout
)
model.cuda()

GRUNet(
  (encoder): Embedding(30522, 512)
  (gru): GRU(512, 256, num_layers=2, batch_first=True, dropout=0.6)
  (decoder_L1): Linear(in_features=51200, out_features=9, bias=True)
  (dropout): Dropout(p=0.6, inplace=False)
  (sigmoid): Sigmoid()
)

In [0]:
def getAcc(preds, target):
  """to calculate the accuracy"""
  pred = preds.detach()
  target = target.detach()
  pred = preds.argmax(dim=1)
  return (pred==target).float().mean()

In [0]:
# defining loss function and optimizers
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = 0.003, betas =  (0.9, 0.999))

## Training and Validation

In [0]:
EPOCHS = 30
trn_loss = []
vld_loss = []

for epoch in range(EPOCHS):

  ############ TRAINING ############
  train_loss = 0
  model.train()
  pbar = tqdm(enumerate(trainloader), leave = False, total = len(trainloader))
  for idx, batch in pbar:
    tokens = batch[0].cuda()
    labels = batch[1].cuda()

    output = model(tokens)
    loss = criterion(output, labels)
    trn_loss.append(loss)
    train_loss+=loss.item()
    loss.backward()
    optimizer.step()

  ############ VALIDATION ############
  valid_loss = 0
  acc = 0
  with torch.no_grad():
    model.eval()
    pbar = tqdm(enumerate(testloader), leave = False, total = len(testloader))
    for idx, batch in pbar:
      tokens = batch[0].cuda()
      labels = batch[1].cuda()

      output = model(tokens)
      loss = criterion(output, labels)
      valid_loss += loss.item()
      acc += getAcc(output, labels).item()

  print('Epoch: {}/{:2} | Train loss: {:.5f} | Valid Loss: {:.5f} | Accuracy: {:.4f}'.format(epoch, EPOCHS, train_loss/(len(trainloader)), valid_loss/(len(testloader)), acc/(len(testloader))))

HBox(children=(IntProgress(value=0, max=2207), HTML(value='')))

HBox(children=(IntProgress(value=0), HTML(value='')))

Epoch: 0/30 | Train loss: 2.18586 | Valid Loss: 2.27856 | Accuracy: 0.1250


HBox(children=(IntProgress(value=0, max=2207), HTML(value='')))

HBox(children=(IntProgress(value=0), HTML(value='')))

Epoch: 1/30 | Train loss: 2.20651 | Valid Loss: 2.27506 | Accuracy: 0.1250


HBox(children=(IntProgress(value=0, max=2207), HTML(value='')))

HBox(children=(IntProgress(value=0), HTML(value='')))

Epoch: 2/30 | Train loss: 2.20645 | Valid Loss: 2.27368 | Accuracy: 0.1250


HBox(children=(IntProgress(value=0, max=2207), HTML(value='')))

HBox(children=(IntProgress(value=0), HTML(value='')))

Epoch: 3/30 | Train loss: 2.20520 | Valid Loss: 2.27830 | Accuracy: 0.1212


HBox(children=(IntProgress(value=0, max=2207), HTML(value='')))

HBox(children=(IntProgress(value=0), HTML(value='')))

Epoch: 4/30 | Train loss: 2.20580 | Valid Loss: 2.28237 | Accuracy: 0.1187


HBox(children=(IntProgress(value=0, max=2207), HTML(value='')))

HBox(children=(IntProgress(value=0), HTML(value='')))

Epoch: 5/30 | Train loss: 2.20431 | Valid Loss: 2.26984 | Accuracy: 0.1237


HBox(children=(IntProgress(value=0, max=2207), HTML(value='')))

HBox(children=(IntProgress(value=0), HTML(value='')))

Epoch: 6/30 | Train loss: 2.20453 | Valid Loss: 2.27173 | Accuracy: 0.1275


HBox(children=(IntProgress(value=0, max=2207), HTML(value='')))

HBox(children=(IntProgress(value=0), HTML(value='')))

Epoch: 7/30 | Train loss: 2.20419 | Valid Loss: 2.28193 | Accuracy: 0.1163


HBox(children=(IntProgress(value=0, max=2207), HTML(value='')))

HBox(children=(IntProgress(value=0), HTML(value='')))

Epoch: 8/30 | Train loss: 2.20366 | Valid Loss: 2.27546 | Accuracy: 0.1263


HBox(children=(IntProgress(value=0, max=2207), HTML(value='')))

HBox(children=(IntProgress(value=0), HTML(value='')))

Epoch: 9/30 | Train loss: 2.20323 | Valid Loss: 2.28307 | Accuracy: 0.1300


HBox(children=(IntProgress(value=0, max=2207), HTML(value='')))

HBox(children=(IntProgress(value=0), HTML(value='')))

Epoch: 10/30 | Train loss: 2.20343 | Valid Loss: 2.28211 | Accuracy: 0.1388


HBox(children=(IntProgress(value=0, max=2207), HTML(value='')))

HBox(children=(IntProgress(value=0), HTML(value='')))

Epoch: 11/30 | Train loss: 2.20627 | Valid Loss: 2.27716 | Accuracy: 0.1375


HBox(children=(IntProgress(value=0, max=2207), HTML(value='')))

HBox(children=(IntProgress(value=0), HTML(value='')))

Epoch: 12/30 | Train loss: 2.20261 | Valid Loss: 2.29674 | Accuracy: 0.1350


HBox(children=(IntProgress(value=0, max=2207), HTML(value='')))

HBox(children=(IntProgress(value=0), HTML(value='')))

Epoch: 13/30 | Train loss: 2.20047 | Valid Loss: 2.28032 | Accuracy: 0.1288


HBox(children=(IntProgress(value=0, max=2207), HTML(value='')))

KeyboardInterrupt: ignored