<a href="https://colab.research.google.com/github/Hotckiss/NLP/blob/master/hw4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import numpy as np
import re
import csv
import string
import nltk
import string
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data.dataset import Dataset
from torch.nn.utils.rnn import pad_sequence
from matplotlib import pyplot as plt
from collections import defaultdict
from sklearn.model_selection import train_test_split
from nltk.stem.snowball import SnowballStemmer
from google.colab import drive
from nltk.corpus import stopwords
from nltk.tokenize import wordpunct_tokenize as tokenize
from nltk.stem.porter import PorterStemmer
from sklearn.feature_extraction.text import TfidfVectorizer
from scipy.spatial.distance import cosine

In [0]:
# Mount Google Drive under ./gdrive; the training CSV is later opened from
# "./gdrive/My Drive/".
drive.mount('./gdrive')

In [0]:
# Download the NLTK resources used below (sentence tokenizer models and
# stop-word lists).
nltk.download('punkt')
# Use the GPU when one is present, otherwise fall back to the CPU so that the
# later `.to(device)` calls do not fail on a CPU-only runtime.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Stemmer shared by vocabulary building and TF-IDF sentence selection.
stemmer = SnowballStemmer("russian")
nltk.download('stopwords')
# A set gives O(1) membership tests when filtering tokens.
stop_words = set(stopwords.words('russian'))

In [0]:
# Maps a stemmed token to its integer index; missing stems resolve to 0.
word_to_idx = defaultdict(int)
# Index inserted between the paragraph tokens and the query tokens.
# NOTE(review): word indices are assigned sequentially from 1, so a vocabulary
# larger than 64999 entries would collide with these two reserved values.
separator = 65000
# Index used to pad encoded samples to a common length.
pad_value = 65001

In [0]:
def plot_loss_values(loss_train, loss_val):
    """Plot training loss and validation loss on a shared iteration axis.

    Args:
        loss_train: Sequence of training loss values, one per optimisation
            step.
        loss_val: Sequence of validation loss values. They are spread evenly
            over the training-iteration axis, the first one at iteration 0.
    """
    train_steps = len(loss_train)
    # The spacing is derived from len(loss_val) instead of a global epoch
    # counter: linspace always returns exactly len(loss_val) positions, while
    # arange with a float step may return one element too many and make
    # matplotlib reject the x/y pair.
    val_positions = np.linspace(0, train_steps, num=len(loss_val), endpoint=False)

    plt.plot(np.arange(train_steps), loss_train, color='blue', label='train')
    plt.plot(val_positions, loss_val, color='red', label='validation')
    plt.legend()
    plt.title("Loss values")
    plt.xlabel("iteration")
    plt.ylabel("loss")
    plt.show()

In [0]:
def locate(text, answer):
    """Find the part of `text` that contains `answer`, ignoring case.

    Every single sentence is tried first; if none contains the answer, every
    pair of neighbouring sentences (joined with one space) is tried next.

    Returns:
        A `(fragment, offset)` tuple where `offset` is the position of the
        answer inside the lower-cased fragment, or `("", -1)` when the answer
        is not found.
    """
    sentences = nltk.sent_tokenize(text, language="russian")

    def candidates():
        # Single sentences take priority over two-sentence windows.
        yield from sentences
        for left, right in zip(sentences, sentences[1:]):
            yield left + " " + right

    for fragment in candidates():
        offset = fragment.lower().find(answer.lower())
        if offset != -1:
            return fragment, offset

    return "", -1

In [0]:
def _clean_query(query):
    """Drop a single trailing question mark from the query, if present."""
    return query[:-1] if query.endswith('?') else query


def _clean_answer(answer):
    """Strip ellipses, one leading/trailing punctuation mark and whitespace.

    Every step is guarded so that an answer which becomes empty part-way
    through (for example the literal "...") does not raise IndexError.
    """
    if answer.endswith("..."):
        answer = answer[:-3]
    if answer.startswith("..."):
        answer = answer[3:]
    if answer and (answer[0] == ' ' or answer[0] in string.punctuation):
        answer = answer[1:]
    if answer and answer[-1] in string.punctuation:
        answer = answer[:-1]
    return answer.strip()


def read_dataset(path="./gdrive/My Drive/train_qa.csv"):
    """Read the training CSV and build token-level QA samples.

    Column 2 of each row is used as the source text, column 3 as the question
    and column 4 as the answer. Rows whose answer cannot be located in the
    text, rows with fewer than five columns and rows whose answer is empty
    after cleaning are skipped.

    Args:
        path: Location of the CSV file; the first row is treated as a header.

    Returns:
        A list of dicts with keys:
            'paragraph': tokens of the sentence(s) containing the answer,
            'query': tokens of the question,
            'answer': `(character offset in the sentence, answer length)`.
    """
    dataset = []

    # newline='' lets the csv module handle line endings inside quoted fields.
    with open(path, encoding='utf-8', newline='') as input_file:
        reader = csv.reader(input_file)
        next(reader, None)  # skip the header; tolerate an empty file

        for row in reader:
            if len(row) < 5:
                continue

            query = _clean_query(row[3])
            answer = _clean_answer(row[4])
            if not answer:
                # An empty answer would "match" at offset 0 of the first
                # sentence and yield a meaningless training sample.
                continue

            sent, ans_start = locate(row[2], answer)

            if ans_start != -1:
                # The capturing group keeps the separators as tokens, so the
                # concatenated tokens reproduce the original string and
                # character offsets stay valid.
                dataset.append({'paragraph': re.split(r'(\W)', sent),
                                'query': re.split(r'(\W)', query),
                                'answer': (ans_start, len(answer))})

    return dataset


In [0]:
# Load the training samples: tokenised sentence, tokenised query, answer span.
dataset = read_dataset()

In [0]:
def my_collate1(row, modify):
    """Encode one sample as a 1-D tensor of vocabulary indices.

    The layout is: paragraph token indices, the `separator` index, query token
    indices. Tokens are lower-cased and stemmed before lookup.

    Args:
        row: Dict with 'paragraph' and 'query' token lists.
        modify: When true, stems not yet in `word_to_idx` are added with the
            next free index. When false, unknown stems are encoded as 0 and
            the vocabulary is left untouched.

    Returns:
        `torch.Tensor` holding the index sequence.
    """
    def encode(word):
        stemmed = stemmer.stem(word.lower())
        if modify and stemmed not in word_to_idx:
            word_to_idx[stemmed] = len(word_to_idx) + 1
        # .get() instead of []: indexing a defaultdict would insert every
        # unknown stem with value 0, growing the vocabulary during read-only
        # encoding and pinning that stem to index 0 on any later modify pass.
        return word_to_idx.get(stemmed, 0)

    idx = [encode(word) for word in row['paragraph']]
    idx.append(separator)
    idx.extend(encode(word) for word in row['query'])

    return torch.tensor(idx)

In [0]:
def my_collate2(row):
    """Return `(start offset, length)` for every paragraph token.

    Offsets are cumulative token lengths, i.e. the character position each
    token would have if all tokens were concatenated.
    """
    spans = []

    for token in row['paragraph']:
        if spans:
            prev_start, prev_len = spans[-1]
            start = prev_start + prev_len
        else:
            start = 0
        spans.append((start, len(token)))

    return spans

In [0]:
def pad_dataset(dataset, update_vocab):
    """Encode every sample and pad the results into one batch-first tensor.

    `update_vocab` is forwarded to `my_collate1`; shorter sequences are padded
    with `pad_value`.
    """
    encoded = []
    for sample in tqdm(dataset):
        encoded.append(my_collate1(sample, update_vocab))

    return pad_sequence(encoded, padding_value=pad_value, batch_first=True)

In [0]:
def pos_dataset(dataset):
    """Compute the per-token `(offset, length)` list for every sample."""
    return list(map(my_collate2, tqdm(dataset)))

In [0]:
# Encode the training samples, growing the vocabulary as new stems appear.
dataset_padded = pad_dataset(dataset, True)
# Character offsets of the paragraph tokens, used to turn the character-level
# answer span into token indices.
dataset_pos = pos_dataset(dataset)

In [0]:
def _answer_token_span(pos, ans_start, ans_len):
    """Translate a character-level answer span into token indices.

    Token boundaries that coincide exactly with the answer boundaries are
    preferred. When a boundary falls inside a token, or the answer ends at the
    very end of the paragraph, the token containing that boundary is used
    instead of silently defaulting to token 0.
    """
    ans_end = ans_start + ans_len
    first = last = None

    for i, (tok_start, _) in enumerate(pos):
        if tok_start == ans_start:
            first = i
        if tok_start == ans_end:
            last = i - 1

    if first is None or last is None:
        for i, (tok_start, tok_len) in enumerate(pos):
            tok_end = tok_start + tok_len
            if first is None and tok_start <= ans_start < tok_end:
                first = i
            if last is None and tok_start < ans_end <= tok_end:
                last = i

    # Clamp: an unmatched boundary maps to 0, and an empty answer at offset 0
    # would otherwise give an end index of -1.
    return max(first or 0, 0), max(last or 0, 0)


def join_dataset(dataset_padded, dataset_pos, dataset):
    """Append the answer's start/end token indices to every encoded sample.

    Args:
        dataset_padded: Iterable of encoded samples (tensors supporting
            `.numpy()`).
        dataset_pos: Per-sample lists of `(offset, length)` for the paragraph
            tokens.
        dataset: Original samples; `row['answer']` is
            `(character offset, character length)`.

    Returns:
        A list of 1-D numpy arrays: the encoded sample followed by two values,
        the start token index and the end token index of the answer.
    """
    dataset_joined = []

    for sent, pos, row in tqdm(zip(dataset_padded, dataset_pos, dataset)):
        ans_start, ans_len = row['answer']
        y = np.array(_answer_token_span(pos, ans_start, ans_len))
        dataset_joined.append(np.append(sent.numpy(), y))

    return dataset_joined

In [0]:
# Each entry: encoded token indices followed by the two answer labels.
dataset_joined = join_dataset(dataset_padded, dataset_pos, dataset)

In [0]:
# Hold out 20% of the samples for validation; the seed is fixed so the split
# is reproducible.
train, val = train_test_split(dataset_joined, test_size=0.2, random_state=42)

In [0]:
class MyLSTM(nn.Module):
    """Bidirectional LSTM that scores every position as answer start or end.

    Args:
        vocab_size: Number of rows in the embedding table; must exceed the
            largest index that will be fed to the model.
        emb_dim: Size of the token embeddings.
        hidden_dim: Hidden size of the LSTM per direction.
        num_layers: Number of stacked LSTM layers.

    The defaults reproduce the original fixed architecture (64 / 64 / 4).
    """

    def __init__(self, vocab_size, emb_dim=64, hidden_dim=64, num_layers=4):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.lstm = nn.LSTM(emb_dim, hidden_dim, num_layers=num_layers,
                            bidirectional=True, batch_first=True)
        # Forward and backward hidden states are concatenated, hence 2x.
        self.fc = nn.Linear(2 * hidden_dim, 2)

    def forward(self, x):
        """Return log-probabilities of shape (batch, 2, seq_len).

        Channel 0 scores positions as the answer start and channel 1 as the
        answer end; the softmax is taken over the sequence positions.
        """
        x = self.emb(x)
        x, _ = self.lstm(x)
        x = self.fc(x)

        # (batch, seq, 2) -> (batch, 2, seq) so that dim=2 runs over positions.
        return F.log_softmax(torch.transpose(x, 1, 2), dim=2)

In [0]:
def _batch_loss(model, batch_data, loss_function):
    """Compute the mean of the start-position and end-position losses.

    The last two columns of `batch_data` are the start/end labels; everything
    before them is the encoded input sequence.
    """
    x = batch_data[:, :-2].to(device)
    y = batch_data[:, -2:].to(device)
    output = model(x.long())
    loss_start = loss_function(output[:, 0], y[:, 0].reshape(-1).long())
    loss_end = loss_function(output[:, 1], y[:, 1].reshape(-1).long())
    return (loss_start + loss_end) / 2


def train_(model, train, val, optimizer, loss_function, epoch_cnt, batch_size):
    """Train `model` and record the loss curves.

    Args:
        model: Network returning scores shaped (batch, 2, seq_len).
        train: Training samples (encoded sequence followed by two labels).
        val: Validation samples in the same layout.
        optimizer: Optimizer already bound to the model parameters.
        loss_function: Criterion applied separately to start and end scores.
        epoch_cnt: Number of passes over the training data.
        batch_size: Batch size for both loaders.

    Returns:
        `(loss_train, loss_val)`: the training loss of every step and the mean
        validation loss of every epoch. The model is left in eval mode after
        the last validation pass.
    """
    train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val, batch_size=batch_size)
    loss_train, loss_val = [], []

    for epoch in tnrange(epoch_cnt, desc='Epoch'):
        model.train()
        for batch_data in train_loader:
            optimizer.zero_grad()
            loss = _batch_loss(model, batch_data, loss_function)
            loss.backward()

            loss_train.append(loss.item())
            # Clip after backward and before the step to limit exploding
            # gradients in the recurrent layers.
            nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()

        model.eval()
        with torch.no_grad():
            loss_values = [_batch_loss(model, batch_data, loss_function).item()
                           for batch_data in val_loader]
        loss_val.append(np.mean(np.array(loss_values)))

    return loss_train, loss_val


In [0]:
# The embedding table must cover every index fed to the model: word indices
# (1..len(word_to_idx)), the separator and the padding value.
vocab_size = max(pad_value, len(word_to_idx)) + 1
epoch_cnt = 10
batch_size = 64

In [0]:
# Build the network and move its parameters to the selected device.
model = MyLSTM(vocab_size)
model = model.float()
model = model.to(device)

In [0]:
# NLLLoss pairs with the log_softmax output of the model.
loss_function = nn.NLLLoss()
# Adam with its default hyper-parameters.
optimizer = optim.Adam(model.parameters())

In [0]:
# Run training; keeps the per-step training loss and per-epoch validation loss.
loss_train, loss_val =\
    train_(model, train, val, optimizer, loss_function, epoch_cnt, batch_size)

In [0]:
# Visualise both loss curves.
plot_loss_values(loss_train, loss_val)

In [0]:
def best_tfidf(text, query):
    """Return the sentence of `text` closest to `query` in TF-IDF space.

    Sentences and the query are stripped of ASCII punctuation, tokenised,
    filtered against the stop-word list and stemmed before vectorisation.

    Returns:
        The original (unprocessed) sentence with the smallest cosine distance
        to the query, or `None` when no sentence has a distance below 1 (no
        shared terms) or nothing remains to vectorise.
    """
    sents = nltk.sent_tokenize(text, language="russian")
    strip_punct = str.maketrans('', '', string.punctuation)

    def normalise(doc):
        words = tokenize(doc.translate(strip_punct))
        return ' '.join(stemmer.stem(w.lower()) for w in words
                        if w.lower() not in stop_words)

    docs = [normalise(doc) for doc in sents + [query]]

    try:
        tf_idf = TfidfVectorizer().fit_transform(docs)
    except ValueError:
        # Raised when every document is empty after preprocessing.
        return None

    # scipy's cosine() expects 1-D arrays, so flatten the sparse rows.
    query_vec = tf_idf[len(sents)].toarray().ravel()
    if not query_vec.any():
        return None

    best_dist, best_idx = 1, None
    for i in range(len(sents)):
        sent_vec = tf_idf[i].toarray().ravel()
        if not sent_vec.any():
            # Distance to a zero vector is undefined; such a sentence cannot
            # be the best match.
            continue
        dist = cosine(sent_vec, query_vec)
        if dist < best_dist:
            best_dist, best_idx = dist, i

    if best_idx is None:
        return None
    return sents[best_idx]

In [0]:
def read_input(input_filename="dataset_281937_1 (3).txt"):
    """Read the tab-separated test file and build token-level samples.

    Column 1 of each row is used as the query id, column 2 as the source text
    and column 3 as the question. The sentence most similar to the question
    (see `best_tfidf`) becomes the paragraph; the whole text is used when no
    sentence qualifies. Rows with fewer than four columns are skipped.

    Args:
        input_filename: Path of the test file; the first row is treated as a
            header.

    Returns:
        A list of dicts with keys 'paragraph', 'query' (token lists) and
        'query_id'.
    """
    dataset = []

    # newline='' lets the csv module handle line endings inside quoted fields.
    with open(input_filename, encoding='utf-8', newline='') as input_file:
        reader = csv.reader(input_file, delimiter='\t')
        next(reader, None)  # skip the header; tolerate an empty file

        for row in reader:
            if len(row) < 4:
                continue

            query = row[3]
            if query.endswith('?'):
                query = query[:-1]

            sent = best_tfidf(row[2], query)
            if sent is None:
                sent = row[2]

            dataset.append({'paragraph': re.split(r'(\W)', sent),
                            'query': re.split(r'(\W)', query),
                            'query_id': row[1]})

    return dataset


In [0]:
# Load the test questions and encode them with the vocabulary frozen
# (update_vocab=False), so unseen stems are encoded as index 0.
test = read_input()
test_padded = pad_dataset(test, False)
test_pos = pos_dataset(test)

In [0]:
# Predict, for every test sample, the most likely start and end position.
# `ans` ends up as an integer array of shape (num_samples, 2).
model.eval()
with torch.no_grad():
    test_loader = torch.utils.data.DataLoader(test_padded, batch_size=batch_size)
    batch_predictions = []

    for batch_data in test_loader:
        x = batch_data.to(device)
        output = model(x.long())
        # argmax over sequence positions for the start and the end channel
        _, ansx = output.max(dim=2)
        batch_predictions.append(ansx.cpu().numpy())

    # Concatenate once instead of re-copying the array for every batch, and
    # keep `ans` a valid (empty) array when there is nothing to predict.
    if batch_predictions:
        ans = np.concatenate(batch_predictions, axis=0)
    else:
        ans = np.empty((0, 2), dtype=np.int64)

In [0]:
# Write one "<query_id>\t<answer text>" line per test sample.
with open("res.txt", "w", encoding="utf-8") as output_file:
    for (first, second), row in zip(ans, test):
        tokens = row['paragraph']
        # Order the two predicted positions so the slice below is never reversed.
        start, end = (second, first) if first > second else (first, second)
        # Predictions that point past the paragraph (into the separator, query
        # or padding) fall back to the first token.
        if end >= len(tokens):
            start = end = 0
        answer_text = "".join(tokens[start:end + 1])
        output_file.write("%s\t%s\n" % (row['query_id'], answer_text))
