<a href="https://colab.research.google.com/github/Ramprabu95/RNN-based-Text-Classification-in-presence-of-noise/blob/main/rcnn_text_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class RCNN(nn.Module):
    """
    Recurrent Convolutional Neural Networks for Text Classification (2015)

    The network embeds token ids, runs a bidirectional LSTM over the
    embeddings, concatenates the LSTM outputs with the embeddings, projects
    them through a tanh-activated linear layer, max-pools over the sequence
    axis and finally maps the pooled vector to class scores.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: size of each token embedding.
        hidden_size: hidden size of the LSTM (per direction).
        hidden_size_linear: output size of the projection layer ``W``.
        class_num: number of output classes.
        dropout: value forwarded to ``nn.LSTM(dropout=...)``.
            NOTE(review): nn.LSTM applies this dropout only between stacked
            layers; with the single layer built here it likely has no
            effect -- confirm against the PyTorch documentation.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size, hidden_size_linear, class_num, dropout):
        super().__init__()
        # Width of the per-token vector after joining both LSTM directions
        # with the raw embedding.
        concat_dim = embedding_dim + 2*hidden_size

        # Index 0 is treated as the padding token by the embedding layer.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_size,
            batch_first=True,
            bidirectional=True,
            dropout=dropout,
        )
        self.W = nn.Linear(concat_dim, hidden_size_linear)
        self.tanh = nn.Tanh()
        self.fc = nn.Linear(hidden_size_linear, class_num)

    def forward(self, x):
        """Return class scores |bs, class_num| for token ids ``x`` |bs, seq_len|."""
        embedded = self.embedding(x)
        # embedded = |bs, seq_len, embedding_dim|
        context, _ = self.lstm(embedded)
        # context = |bs, seq_len, 2*hidden_size|
        joined = torch.cat([context, embedded], 2)
        # joined = |bs, seq_len, embedding_dim + 2*hidden_size|
        latent = self.W(joined)
        latent = self.tanh(latent)
        latent = latent.transpose(1, 2)
        # latent = |bs, hidden_size_linear, seq_len|
        window = latent.size(2)
        # Pool over the full sequence so one value per feature remains.
        pooled = F.max_pool1d(latent, window).squeeze(2)
        # pooled = |bs, hidden_size_linear|
        return self.fc(pooled)
        # returned scores = |bs, class_num|

In [None]:

import torch
from torch.utils.data import Dataset


class CustomTextDataset(Dataset):
    """Dataset of tokenised texts encoded as vocabulary indices.

    Args:
        texts: iterable of token lists (one list of tokens per example).
        labels: indexable collection of labels, aligned with ``texts``.
        dictionary: mapping from token to integer index.
        unk_idx: index used for tokens missing from ``dictionary``.
            Defaults to 1 (<UNK>).
    """

    def __init__(self, texts, labels, dictionary, unk_idx=1):
        # Encode eagerly so __getitem__ is a plain list lookup.
        self.x = [
            [dictionary.get(token, unk_idx) for token in token_list]
            for token_list in texts
        ]
        self.y = labels

    def __len__(self):
        """Return the data length"""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the (encoded token list, label) pair at ``idx``"""
        return self.x[idx], self.y[idx]


def collate_fn(data, args, pad_idx=0):
    """Pad or truncate every text to ``args.max_len`` and build tensors.

    Args:
        data: sequence of ``(token_index_list, label)`` pairs.
        args: object exposing ``max_len``, the target sequence length.
        pad_idx: index appended to texts shorter than ``args.max_len``.

    Returns:
        Tuple ``(texts, labels)`` of ``torch.LongTensor``.
    """
    texts, labels = zip(*data)
    fixed_length = []
    for seq in texts:
        if len(seq) < args.max_len:
            # Too short: right-pad up to the target length.
            missing = args.max_len - len(seq)
            fixed_length.append(seq + [pad_idx] * missing)
        else:
            # Long enough: keep only the leading max_len tokens.
            fixed_length.append(seq[:args.max_len])
    return torch.LongTensor(fixed_length), torch.LongTensor(labels)

In [None]:
import pandas as pd
import re
import nltk
from nltk.tokenize import word_tokenize
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
# Download the 'punkt' NLTK resource when this cell/module runs
# (presumably required by word_tokenize below -- confirm for the installed NLTK version).
nltk.download('punkt')


def read_file(file_path):
    """
    Read function for AG NEWS Dataset

    The CSV is read without a header row as the columns
    ``class``, ``title`` and ``description``. Title and description are
    joined with a single space, cleaned with ``preprocess_text`` and split
    with ``word_tokenize``.

    Returns:
        Tuple ``(texts, labels)``: ``texts`` is a list of token lists and
        ``labels`` is a list of class values shifted down by one.
    """
    frame = pd.read_csv(file_path, names=["class", "title", "description"])

    titles = frame['title'].values
    descriptions = frame['description'].values
    sentences = list(titles + ' ' + descriptions)

    texts = []
    for sentence in sentences:
        cleaned = preprocess_text(sentence)
        texts.append(word_tokenize(cleaned))

    # label : 1~4  -> label : 0~3
    labels = list(frame['class'].values - 1)
    return texts, labels


def preprocess_text(string):
    """
    Lower-case and normalise a sentence before tokenisation.

    reference : https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py

    Steps, in order:
      1. lower-case the input;
      2. replace every character outside ``A-Za-z0-9(),!?'`` and the
         backtick with a space;
      3. split common English clitics (``'s``, ``'ve``, ``n't``, ``'re``,
         ``'d``, ``'ll``) from the preceding word;
      4. surround ``,``, ``!``, ``(``, ``)`` and ``?`` with spaces;
      5. collapse whitespace runs and strip both ends.

    Unlike the referenced code, parentheses and question marks are emitted
    as plain characters. The previous replacement strings ``" \\( "``,
    ``" \\) "`` and ``" \\? "`` were invalid escape sequences in non-raw
    literals and inserted a literal backslash into the text.

    Args:
        string: the raw sentence.

    Returns:
        The normalised sentence as a single space-separated string.
    """
    string = string.lower()
    string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)

    # (pattern, replacement) pairs applied sequentially; order matters
    # because whitespace collapsing must run after all the padding steps.
    substitutions = (
        (r"\'s", " 's"),
        (r"\'ve", " 've"),
        (r"n\'t", " n't"),
        (r"\'re", " 're"),
        (r"\'d", " 'd"),
        (r"\'ll", " 'll"),
        (r",", " , "),
        (r"!", " ! "),
        (r"\(", " ( "),
        (r"\)", " ) "),
        (r"\?", " ? "),
        (r"\s{2,}", " "),
    )
    for pattern, replacement in substitutions:
        string = re.sub(pattern, replacement, string)
    return string.strip()


def metrics(dataloader, losses, correct, y_hats, targets):
    """Summarise an evaluation pass.

    Args:
        dataloader: loader that was iterated; ``len(dataloader)`` is used as
            the divisor for the loss and ``len(dataloader.dataset)`` as the
            divisor for accuracy.
        losses: sum of the per-batch loss values.
        correct: number of correctly predicted examples.
        y_hats: predicted labels.
        targets: ground-truth labels.

    Returns:
        Tuple ``(avg_loss, accuracy, precision, recall, f1, cm)`` where
        accuracy is a percentage, precision/recall/f1 are macro-averaged and
        ``cm`` is the confusion matrix.
    """
    num_batches = len(dataloader)
    num_examples = len(dataloader.dataset)

    macro = {'average': 'macro'}
    return (
        losses / num_batches,
        correct / num_examples * 100,
        precision_score(targets, y_hats, **macro),
        recall_score(targets, y_hats, **macro),
        f1_score(targets, y_hats, **macro),
        confusion_matrix(targets, y_hats),
    )

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [None]:

import os
import logging
import torch
import torch.nn.functional as F

from utils import metrics

# Configure logging at import time: timestamped messages at INFO level and above.
logging.basicConfig(format='%(asctime)s -  %(message)s', datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
# Module-level logger used by train() for progress and checkpoint messages.
logger = logging.getLogger(__name__)


def train(model, optimizer, train_dataloader, valid_dataloader, args):
    """Train ``model`` and checkpoint the weights with the best validation F1.

    For each epoch the model is optimised with cross-entropy loss over
    ``train_dataloader`` and then scored on ``valid_dataloader`` through
    ``evaluate``. Whenever the validation F1 exceeds the best value seen so
    far, the model's ``state_dict`` is written to
    ``<args.model_save_path>/best.pt``.

    Args:
        model: module mapping a batch ``x`` to class scores.
        optimizer: optimizer updating ``model``'s parameters.
        train_dataloader: iterable of ``(x, y)`` training batches.
        valid_dataloader: iterable of ``(x, y)`` validation batches.
        args: object exposing ``epochs``, ``device`` and ``model_save_path``.
    """
    best_f1 = 0
    logger.info('Start Training!')
    for epoch in range(1, args.epochs+1):
        # evaluate() switches the model to eval mode, so re-enable training
        # mode at the start of every epoch.
        model.train()
        for step, (x, y) in enumerate(train_dataloader):
            x, y = x.to(args.device), y.to(args.device)
            pred = model(x)
            loss = F.cross_entropy(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Report the current batch loss every 200 steps.
            if (step+1) % 200 == 0:
                logger.info(f'|EPOCHS| {epoch:>}/{args.epochs} |STEP| {step+1:>4}/{len(train_dataloader)} |LOSS| {loss.item():>.4f}')

        avg_loss, accuracy, _, _, f1, _ = evaluate(model, valid_dataloader, args)
        logger.info('-'*50)
        logger.info(f'|* VALID SET *| |VAL LOSS| {avg_loss:>.4f} |ACC| {accuracy:>.4f} |F1| {f1:>.4f}')
        logger.info('-'*50)

        if f1 > best_f1:
            best_f1 = f1
            logger.info(f'Saving best model... F1 score is {best_f1:>.4f}')
            # makedirs(exist_ok=True) creates nested save paths and avoids
            # the check-then-create race of isdir() followed by mkdir().
            os.makedirs(args.model_save_path, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(args.model_save_path, "best.pt"))
            logger.info('Model saved!')


def evaluate(model, valid_dataloader, args):
    """Score ``model`` on ``valid_dataloader`` without tracking gradients.

    The model is put in eval mode and left in that mode on return.

    Args:
        model: module mapping a batch ``x`` to class scores.
        valid_dataloader: iterable of ``(x, y)`` batches; also passed to
            ``metrics`` for its length information.
        args: object exposing ``device``.

    Returns:
        Tuple ``(avg_loss, accuracy, precision, recall, f1, cm)`` as
        produced by ``metrics``.
    """
    total_loss = 0
    num_correct = 0
    predictions = []
    references = []

    with torch.no_grad():
        model.eval()
        for batch_x, batch_y in valid_dataloader:
            batch_x = batch_x.to(args.device)
            batch_y = batch_y.to(args.device)

            scores = model(batch_x)
            total_loss += F.cross_entropy(scores, batch_y).item()

            # Index of the highest score along the class axis is the prediction.
            _, predicted = torch.max(scores, 1)
            predictions += predicted.tolist()
            references += batch_y.tolist()
            num_correct += (predicted == batch_y).sum().item()

    return metrics(valid_dataloader, total_loss, num_correct, predictions, references)



ImportError: ignored