In [None]:
import os
import re
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.sampler import SubsetRandomSampler
from model.classifier import RNNClassifier
from torchsummary import summary
from dataset.chat_dataset import preprocess_data, ChatDataset, create_vocab
from collections import Counter

%load_ext autoreload
%autoreload 2


# 1. Loading and Processing Chat Data

Chat data should be in the dataset folder and named "_chat.txt". First, open the text file in VSCode and check whether [U+200E] characters are present; if so, remove all occurrences.

In [None]:
# Locate the chat directory relative to the notebook's current working directory.
# `path` is reused later to build the model save path.
path = os.path.abspath(os.getcwd())
chat_dir = os.path.join(path, "dataset")
# NOTE(review): preprocess_data is project code. Its result is used later as a mapping
# (`.keys()` / `.items()`), presumably sender name -> class index — confirm.
sender_indices = preprocess_data(chat_dir)

print(sender_indices)

# 2. Tokenizing Data and Creating Vocabulary

Now that we have preprocessed the data we can create our vocabulary.

In [None]:
# Build the vocabulary and tokenize the chat. Later cells use `vocab` as a
# token -> index mapping that contains an '<unk>' entry, `tokenized_data` as
# (tokens, label) pairs, and `lines` as a sequence indexed in step with `tokenized_data`.
# NOTE(review): `threshold=3` is presumably a minimum token frequency — confirm in create_vocab.
vocab, tokenized_data, lines = create_vocab(chat_dir, sender_indices, threshold=3)

# 3. Creating the Dataset


In [None]:
# Replace each token with its vocabulary index; tokens missing from the
# vocabulary fall back to the index of '<unk>'. Labels are carried over as-is.
indexed_data = [
    ([vocab.get(token, vocab['<unk>']) for token in tokens], label)
    for tokens, label in tokenized_data
]


In [None]:
# One record per chat line: (raw line, tokens, token indices, label).
# Indexing is driven by `lines`, so a shorter `tokenized_data` / `indexed_data`
# still raises IndexError instead of being silently truncated.
combined_data = [
    (lines[i], tokenized_data[i][0], indexed_data[i][0], indexed_data[i][1])
    for i in range(len(lines))
]

dataset = ChatDataset(combined_data)

# 4. Creating the DataLoaders

In [None]:
def collate(batch):
    """Merge a list of dataset samples into one padded batch.

    Args:
        batch: list of dicts, each with a 1-D tensor under 'data' and a
            tensor under 'label'.

    Returns:
        dict with
          'data'    - sequences padded by ``pad_sequence`` with its defaults
                      (sequence-first layout, zero padding),
          'label'   - the labels stacked along a new first dimension,
          'lengths' - tensor holding the unpadded length of every sequence.
    """
    assert isinstance(batch, list)
    sequences = [sample['data'] for sample in batch]
    padded = pad_sequence(sequences)
    seq_lengths = torch.tensor(list(map(len, sequences)))
    targets = torch.stack([sample['label'] for sample in batch])
    return dict(data=padded, label=targets, lengths=seq_lengths)

In [None]:
batch_size = 64
validation_split = .1
shuffle_dataset = True
random_seed = 42

# Seed torch as well as NumPy: SubsetRandomSampler shuffling (and any later
# weight initialisation) draws from torch's global RNG, so seeding NumPy alone
# does not make a Restart & Run All reproducible.
torch.manual_seed(random_seed)

# Hold out the first `split` (shuffled) indices for validation.
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

# Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

train_loader = DataLoader(dataset, batch_size=batch_size,
                          sampler=train_sampler, collate_fn=collate)
val_loader = DataLoader(dataset, batch_size=batch_size,
                        sampler=valid_sampler, collate_fn=collate)

# 5. Create and Train Classifier

## Evaluation Metrics

In [None]:
# Prefer the GPU when CUDA is usable, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
@torch.no_grad()
def compute_accuracy(model, data_loader):
    """Return the fraction of samples in `data_loader` that `model` classifies correctly.

    The model is switched to eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards, so layers that behave
    differently in training do not distort the metric.

    Args:
        model: module called as ``model(data, lengths)`` that returns one
            score per class for every sample (argmax is taken over dim 1).
        data_loader: iterable of batches, each a dict with 'data', 'lengths'
            and 'label' tensors.

    Returns:
        0-dim float tensor with the accuracy, or a NaN tensor when the loader
        yields no samples.
    """
    device = next(model.parameters()).device
    corrects = torch.zeros((), dtype=torch.long, device=device)
    total = 0

    was_training = model.training
    model.eval()
    try:
        for i, x in enumerate(data_loader):
            inputs = x['data'].to(device)
            # Lengths are passed through on whatever device the loader produced them.
            lengths = x['lengths']
            label = x['label'].to(device)
            pred = torch.argmax(model(inputs, lengths), dim=1)
            corrects += torch.count_nonzero(torch.eq(pred, label))
            total += label.numel()

            if i > 0 and i % 100 == 0:
                print('Step {} / {}'.format(i, len(data_loader)))
    finally:
        # Hand the model back in the mode the caller left it in.
        model.train(was_training)

    if total == 0:
        # Nothing was evaluated (e.g. an empty validation split): report NaN
        # instead of failing with a division by zero.
        return torch.tensor(float('nan'), device=device)

    return corrects / total

## Train Model

In [None]:
from tqdm import tqdm
from torch.optim.lr_scheduler import ExponentialLR

# File the trained model is written to.
model_save_path = os.path.join(path, "model", "chat_model.p")

# To resume from a previously saved model instead of training from scratch:
# model = torch.load(model_save_path)

# Vocabulary size in, one output class per sender; built directly on the
# selected device.
# NOTE(review): the positional 100 and 64 are presumably embedding and hidden
# sizes — confirm against RNNClassifier's signature.
model = RNNClassifier(len(vocab), 100, 64, len(sender_indices), num_layers=1).to(device)

# Maximum gradient norm; read as a global by train().
gclip = 8

def train(model, optimizer, train_loader, val_loader, loss_func, sched=None, epochs=10,
          grad_clip=None):
    """Train `model` for `epochs` passes over `train_loader`.

    Args:
        model: module called as ``model(data, lengths)``; batches are moved to
            the device its parameters live on.
        optimizer: optimizer stepped once per batch.
        train_loader: iterable of dict batches with 'data', 'label' and
            'lengths' tensors.
        val_loader: currently unused; kept so existing callers keep working.
        loss_func: callable ``loss_func(outputs, labels)`` returning the loss.
        sched: optional learning-rate scheduler, stepped once per epoch.
        epochs: number of passes over `train_loader`.
        grad_clip: maximum gradient norm. When None, falls back to the
            notebook-level `gclip` value.
    """
    # Follow the model's own device rather than a notebook global, as
    # compute_accuracy does.
    model_device = next(model.parameters()).device
    max_norm = gclip if grad_clip is None else grad_clip

    for epoch_id in range(epochs):
        # Re-assert train mode every epoch in case something switched the
        # model to eval mode in between.
        model.train()
        with tqdm(train_loader, unit="batch") as tepoch:
            tepoch.set_description(f'Epoch {epoch_id + 1}')
            for batch in tepoch:
                inputs = batch['data'].to(model_device)
                labels = batch['label'].to(model_device)
                # NOTE(review): compute_accuracy passes lengths without moving
                # them; torch's pack_padded_sequence requires CPU lengths —
                # confirm which placement RNNClassifier.forward expects.
                lengths = batch['lengths'].to(model_device)

                optimizer.zero_grad()
                outputs = model(inputs, lengths)
                loss = loss_func(outputs, labels)
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), max_norm)
                optimizer.step()
                tepoch.set_postfix(loss=loss.item())
        if sched is not None:
            sched.step()
            

In [None]:
from prettytable import PrettyTable
def count_parameters(model):
    """Print a per-parameter table of trainable sizes and return their sum.

    Parameters with ``requires_grad`` set to False are left out of both the
    table and the total.

    Args:
        model: object exposing ``named_parameters()``.

    Returns:
        Total number of elements across the trainable parameters.
    """
    table = PrettyTable(["Mod name", "Parameters Listed"])
    trainable = [
        (name, parameter.numel())
        for name, parameter in model.named_parameters()
        if parameter.requires_grad
    ]
    for name, size in trainable:
        table.add_row([name, size])
    t_params = sum(size for _, size in trainable)
    print(table)
    print(f"Sum of trained paramters: {t_params}")
    return t_params
# Print the parameter table for the model; as the cell's last expression, the
# returned total is also shown as the cell output.
count_parameters(model)


In [None]:
# Loss, optimizer and learning-rate schedule for the run.
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# train() steps the scheduler once per epoch, so the learning rate is
# multiplied by 0.9 after every epoch.
sched = ExponentialLR(optimizer, gamma=0.9)

train(
    model,
    optimizer,
    train_loader,
    val_loader,
    loss_func,
    sched=sched,
    epochs=20,
)

## Test Model

In [None]:

# Accuracy is measured on the held-out split served by val_loader.
print(f"accuracy on test set: {compute_accuracy(model, val_loader)}")

In [None]:

# Persist the whole model object (not just a state_dict) to model_save_path.
# NOTE(review): this pickles the module, so loading it back needs the same class
# definitions importable, and torch.load unpickles — only load files you trust.
torch.save(model, model_save_path)

# 6. Predict Input

In [None]:
from dataset.chat_dataset import tokenize
from time import sleep
# Inference only: switch off training-time behaviour.
model.eval()

# Invert the sender mapping so a predicted class index can be turned back into
# the key it was stored under.
pred_indices = {value: key for (key, value) in sender_indices.items()}

text = input("Enter text: ")
tokens = tokenize(text.lower())

if not tokens:
    # An empty token list would become a float-typed (0, 1) tensor, which is
    # not a valid index sequence for the model; skip prediction instead.
    print("No tokens found in the input; nothing to classify.")
else:
    # Same lookup as for the training data: unknown tokens map to '<unk>'.
    indices = [vocab.get(token, vocab['<unk>']) for token in tokens]
    # Shape (1, seq_len) -> (seq_len, 1): the sequence-first layout produced by collate().
    sequence = torch.tensor([indices], dtype=torch.long).permute(1, 0).to(device)
    with torch.no_grad():
        pred = model.predict(sequence)
    print(f'{pred_indices[pred.item()]}: {text}')
