In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import transformers
import json
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
import copy

import pandas as pd

import os
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

In [2]:
class SlotIntentDataset(Dataset):
    
    def __init__(self, datapath):
        self.data = []
        with open(datapath, 'r') as jsonl_file:
            for line in jsonl_file:
                self.data.append(json.loads(line))
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return (self.data[idx]['input'], ", ".join(self.data[idx]['user_contacts']), self.data[idx]['output'])

def dl_collate_fn(batch):
    return list(batch)

## Training

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data_path = '/kaggle/input/col772a3-data/A3'

In [4]:
train_ds = SlotIntentDataset(f'{data_path}/train.jsonl')
val_ds = SlotIntentDataset(f'{data_path}/dev.jsonl')

In [5]:
DEBUG = False
if DEBUG:
    train_ds.data = train_ds.data[:128]
    val_ds.data = val_ds.data[:32]

In [6]:
train_dl = DataLoader(train_ds, batch_size=16, num_workers=2, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=16, num_workers=2, shuffle=False)

In [7]:
model = transformers.BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(device)
tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
optimizer = optim.Adam(model.parameters(), lr=5e-5)

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.58k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [8]:
def process_batch(batch, tokenizer):
    encoder_strs = [f'[{b}] {a}' for a,b in zip(batch[0], batch[1])]
    decoder_strs = batch[2]
    
    encoder_toks = tokenizer(encoder_strs, padding=True, truncation=True, return_tensors='pt').to(device)
    decoder_toks = tokenizer(decoder_strs, padding=True, truncation=True, return_tensors='pt').to(device)
    return encoder_toks, decoder_toks

In [11]:
def train(model, tokenizer, train_dl, val_dl, optimizer, scheduler=None, max_epochs=20, patience_lim=2):

    best_model = None
    best_val_loss = 10000
    val_losses = []
    train_losses = []
    patience = 0

    for epoch in range(max_epochs):

        print(f'Epoch {epoch+1}:')
        train_loss = torch.tensor(0, dtype=torch.float, device=device)
        model.train()
        for batch in tqdm(train_dl):
            encoder_toks, decoder_toks = process_batch(batch, tokenizer)
            
            optimizer.zero_grad()
            loss = model(
                **encoder_toks,
                decoder_input_ids=decoder_toks['input_ids'],
                decoder_attention_mask=decoder_toks['attention_mask'],
                labels=decoder_toks['input_ids']
            ).loss
            loss.backward()
            optimizer.step()

            train_loss += loss.detach()
        
        if scheduler:
            scheduler.step()

        train_loss = train_loss.cpu()
        train_loss /= len(train_dl)
        print(f' Train Loss: {train_loss}')
        train_losses.append(train_loss)

        val_loss = torch.tensor(0, dtype=torch.float, device=device)
        true_labels = []
        pred_labels = []
        model.eval()
        for batch in tqdm(val_dl):
            encoder_toks, decoder_toks = process_batch(batch, tokenizer)
            
            loss = model(
                **encoder_toks,
                decoder_input_ids=decoder_toks['input_ids'],
                decoder_attention_mask=decoder_toks['attention_mask'],
                labels=decoder_toks['input_ids']
            ).loss

            val_loss += loss.detach()
            
        val_loss = val_loss.cpu()
        val_loss /= len(val_dl)
        val_losses.append(val_loss)

        print(f' Val Loss: {val_loss}')
        print('')

        # early stopping
        if val_loss >= best_val_loss:
            if patience >= patience_lim:
                break
            else:
                patience += 1
        else:
            patience = 0
            best_val_loss = val_loss
            best_model = copy.deepcopy(model)
            best_model = best_model.cpu()
    
    return best_model, (train_losses, val_losses)

In [None]:
best_model, (train_losses, val_losses) = train(model, tokenizer, train_dl, val_dl, optimizer)
torch.save(best_model, 'intent-slot-bart.pt')

Epoch 1:


  0%|          | 0/8 [00:00<?, ?it/s]

 Train Loss: 1.668185830116272


  0%|          | 0/2 [00:00<?, ?it/s]

 Val Loss: 0.5320673584938049

Epoch 2:


  0%|          | 0/8 [00:00<?, ?it/s]

 Train Loss: 0.359955757856369


  0%|          | 0/2 [00:00<?, ?it/s]

 Val Loss: 0.16559645533561707

Epoch 3:


  0%|          | 0/8 [00:00<?, ?it/s]

 Train Loss: 0.13355059921741486


  0%|          | 0/2 [00:00<?, ?it/s]

 Val Loss: 0.08556962013244629

Epoch 4:


  0%|          | 0/8 [00:00<?, ?it/s]

 Train Loss: 0.0669875368475914


  0%|          | 0/2 [00:00<?, ?it/s]

 Val Loss: 0.05432933568954468

Epoch 5:


  0%|          | 0/8 [00:00<?, ?it/s]

 Train Loss: 0.03958814591169357


  0%|          | 0/2 [00:00<?, ?it/s]

 Val Loss: 0.03992871195077896



Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x795682d5fb00>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1466, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1430, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/opt/conda/lib/python3.7/multiprocessing/process.py", line 140, in join
    res = self._popen.wait(timeout)
  File "/opt/conda/lib/python3.7/multiprocessing/popen_fork.py", line 45, in wait
    if not wait([self.sentinel], timeout):
  File "/opt/conda/lib/python3.7/multiprocessing/connection.py", line 921, in wait
    ready = selector.select(timeout)
  File "/opt/conda/lib/python3.7/selectors.py", line 415, in select
    fd_event_list = self._selector.poll(timeout)
KeyboardInterrupt: 


Epoch 6:


  0%|          | 0/8 [00:00<?, ?it/s]

 Train Loss: 0.026360616087913513


  0%|          | 0/2 [00:00<?, ?it/s]

 Val Loss: 0.033907756209373474

Epoch 7:


  0%|          | 0/8 [00:00<?, ?it/s]

 Train Loss: 0.0185256190598011


  0%|          | 0/2 [00:00<?, ?it/s]

 Val Loss: 0.029269695281982422



In [None]:
import matplotlib.pyplot as plt
plt.plot(train_losses)
plt.plot(val_losses)