In [1]:
import math
import os
from collections import defaultdict

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from tqdm.notebook import tqdm
from transformers import BertModel, BertTokenizer

# Input data and output locations.
ASAP_PATH = 'asap-aes/training_set_rel3.tsv'
DATA_PATH = 'data'
CHECKPOINT_PATH = 'saved_models/multi_scale_model'

os.makedirs(DATA_PATH, exist_ok=True)
os.makedirs(CHECKPOINT_PATH, exist_ok=True)
# Reproducibility: Lightning's global seed.
pl.seed_everything(42)
# Always a torch.device (previously a torch.device on GPU but the plain string 'cpu' otherwise).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

  from .autonotebook import tqdm as notebook_tqdm
Global seed set to 42


Device: cuda:0


# Preprocess Data

In [2]:
# Load data (the ASAP TSV is read with the 'latin' encoding).
asap_data = pd.read_table(ASAP_PATH, encoding='latin')

# Clean data: drop every column that has at least one missing value.
# Rebinding instead of `inplace=True` keeps the cell's lineage explicit.
asap_data = asap_data.dropna(axis=1)

# Split data by `essay_set` (1-8). Every split gets a `score` column holding the raw
# `domain1_score`. `.assign` builds a new frame instead of writing into a filtered
# slice, which triggered SettingWithCopyWarning and only added the column when the
# CSV did not exist yet.
grouped_data = [
    asap_data[asap_data['essay_set'] == i].assign(score=lambda frame: frame['domain1_score'])
    for i in range(1, 9)
]
for set_id, group_data in enumerate(grouped_data):
    csv_path = os.path.join(DATA_PATH, 'essay_set%d.csv' % (set_id+1))
    if not os.path.exists(csv_path):
        group_data.to_csv(csv_path)

# Normalize Score to 0-10
# Inclusive (min, max) of `domain1_score` for each `essay_set` id.
# NOTE(review): the ASAP training TSV only contains sets 1-8; the entries for 9 and 10
# are not used anywhere visible in this notebook — confirm where they come from.
asap_ranges = {
    1: (2.0, 12.0),
    2: (1.0, 6.0),
    3: (0.0, 3.0),
    4: (0.0, 3.0),
    5: (0.0, 4.0),
    6: (0.0, 4.0),
    7: (0.0, 30.0),
    8: (0.0, 60.0),
    9: (0.5, 9.0),
    10: (1.0, 24.0),
}


def normalize_score(data, score_ranges=None, scale=10):
    """Linearly rescale each row's `domain1_score` onto [0, `scale`].

    Args:
        data: DataFrame-like with `domain1_score` and `essay_set` columns.
        score_ranges: mapping essay_set -> (min_score, max_score); defaults to
            `asap_ranges`.
        scale: upper end of the target interval (10 by default).

    Returns:
        List of floats, one per row, in row order.

    Raises:
        KeyError: if a row's `essay_set` has no entry in `score_ranges`.
    """
    if score_ranges is None:
        score_ranges = asap_ranges
    normalized = []
    for score, essay_set in zip(data['domain1_score'], data['essay_set']):
        min_score, max_score = score_ranges[essay_set]
        normalized.append(scale * (score - min_score) / (max_score - min_score))
    return normalized

# Save whole data: attach the normalized (0-10) score to every essay, then cache the
# full table under DATA_PATH unless a cached CSV already exists.
csv_path = os.path.join(DATA_PATH, 'essay_set_all.csv')
asap_data['score'] = normalize_score(asap_data)
if not os.path.exists(csv_path):
    asap_data.to_csv(csv_path)

# Build Dataset

In [3]:
class AESDataset(data.Dataset):
    """Multi-scale BERT inputs for automated essay scoring.

    Each essay is word-piece tokenized once and then encoded at two scales:

    * document scale (``'doc'``): one ``[CLS] tokens [SEP]`` sequence of
      ``max_input_length`` ids; longer essays are truncated;
    * segment scale (``'seg'``): for every ``chunk_size`` in ``chunk_size_list`` the
      essay is cut into consecutive pieces of ``chunk_size - 2`` word pieces, each
      wrapped in ``[CLS]``/``[SEP]`` and zero-padded to ``chunk_size``.

    Args:
        data: ``(documents, labels)`` pair of equally long sequences (e.g. two pandas
            Series). Both are read positionally, so filtered/shuffled indices are fine.
        tokenizer: BERT-style tokenizer with ``tokenize`` and ``convert_tokens_to_ids``;
            its ``cls_token``/``sep_token`` are used when present.
        max_input_length: id length of the document-scale input.
        chunk_size_list: id lengths of the segment-scale inputs (each must be >= 3).
    """

    def __init__(self, data, tokenizer, max_input_length=512, chunk_size_list=(90, 30, 130, 10)):
        self.data = data
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        # Private list copy, so a default or caller-owned list is never aliased.
        self.chunk_size_list = list(chunk_size_list)
        # Positional labels: `self.data[1][index]` is a *label* lookup on a pandas
        # Series and returns the wrong row (or KeyError) for a non-RangeIndex split.
        self.labels = list(data[1])
        if len(self.labels) != len(data[0]):
            raise ValueError('got %d documents but %d labels' % (len(data[0]), len(self.labels)))
        self.tokenized_data = dict()
        self.tokenized_documents = None
        self.encode_documents()

    def encode_documents(self):
        """Fill ``self.tokenized_data`` with the document- and segment-scale encodings."""
        self.tokenized_data['doc'] = self.encode_documents_by_chunk(self.max_input_length, doc=True)
        self.tokenized_data['seg'] = dict()
        for chunk_size in self.chunk_size_list:
            self.tokenized_data['seg'][chunk_size] = self.encode_documents_by_chunk(chunk_size, doc=False)

    def encode_documents_by_chunk(self, chunk_size, doc=False):
        """Encode every essay as zero-padded ``[CLS] piece [SEP]`` segments of ``chunk_size`` ids.

        Args:
            chunk_size: total id length of one segment, including [CLS] and [SEP].
            doc: if True, keep only the first segment (document scale, truncating).

        Returns:
            ``(input_ids, token_type_ids, attention_mask)`` LongTensors of shape
            (num_docs, num_segments, chunk_size). num_segments is 1 when ``doc`` is
            True, otherwise enough segments for the longest essay (at least 1).

        Raises:
            ValueError: if ``chunk_size`` leaves no room for tokens.
        """
        if self.tokenized_documents is None:
            # Use this dataset's tokenizer, not a notebook-global `tokenizer`.
            self.tokenized_documents = [self.tokenizer.tokenize(text) for text in self.data[0]]
        max_segment_length = chunk_size - 2  # leave room for [CLS] and [SEP]
        if max_segment_length < 1:
            raise ValueError('chunk_size must be at least 3, got %d' % chunk_size)

        longest = max((len(tokens) for tokens in self.tokenized_documents), default=0)
        # The document scale only keeps the first segment, so don't allocate the rest.
        segment_num = 1 if doc else max(1, math.ceil(longest / max_segment_length))
        shape = (len(self.tokenized_documents), segment_num, chunk_size)
        # Zero-initialised: ids are padded with 0, token types stay 0, padding is masked out.
        total_input_ids = torch.zeros(size=shape, dtype=torch.long)
        total_token_type_ids = torch.zeros(size=shape, dtype=torch.long)
        total_attention_mask = torch.zeros(size=shape, dtype=torch.long)

        # Real BERT special tokens; the former 'CLS'/'SEP' strings are not vocabulary entries.
        cls_token = getattr(self.tokenizer, 'cls_token', None) or '[CLS]'
        sep_token = getattr(self.tokenizer, 'sep_token', None) or '[SEP]'
        for doc_index, tokenized_document in enumerate(tqdm(self.tokenized_documents)):
            for seg_index, seg_begin in enumerate(range(0, len(tokenized_document), max_segment_length)):
                if seg_index >= segment_num:
                    break
                tokens = [cls_token] + tokenized_document[seg_begin:seg_begin + max_segment_length] + [sep_token]
                token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
                total_input_ids[doc_index, seg_index, :len(token_ids)] = torch.tensor(token_ids, dtype=torch.long)
                total_attention_mask[doc_index, seg_index, :len(token_ids)] = 1

        return total_input_ids, total_token_type_ids, total_attention_mask

    def __getitem__(self, index):
        """Return the label plus document- and segment-scale inputs of essay ``index``.

        ``sample['doc'][k]`` has shape (1, max_input_length); ``sample['seg']['seg-<c>'][k]``
        is sliced to keep a leading dim: (1, num_segments, c).
        """
        sample = {'labels': self.labels[index], 'chunk_size_list': self.chunk_size_list}

        # Document-Scale Data
        doc_ids, doc_types, doc_mask = self.tokenized_data['doc']
        sample['doc'] = {
            'input_ids': doc_ids[index],
            'token_type_ids': doc_types[index],
            'attention_mask': doc_mask[index],
        }

        # Segment-Scale Data
        sample['seg'] = dict()
        for chunk_size in self.chunk_size_list:
            seg_ids, seg_types, seg_mask = self.tokenized_data['seg'][chunk_size]
            sample['seg']['seg-%d' % chunk_size] = {
                'input_ids': seg_ids[index:index+1],
                'token_type_ids': seg_types[index:index+1],
                'attention_mask': seg_mask[index:index+1],
            }

        return sample

    def __len__(self):
        return len(self.data[0])

    def __str__(self):
        return str(self[0].keys())
        

In [4]:
def callate_fn(examples):
    """Collate ``AESDataset`` samples into one batch (for use as a DataLoader ``collate_fn``).

    Args:
        examples: non-empty list of samples from ``AESDataset.__getitem__``; all are
            assumed to share the first example's ``chunk_size_list``.

    Returns:
        dict with keys, in order:
        * ``labels``: FloatTensor with one entry per example;
        * ``doc``: ``input_ids``/``token_type_ids``/``attention_mask``, each the
          examples' tensors concatenated along dim 0;
        * ``seg``: the same per ``'seg-<chunk_size>'`` scale, concatenated along dim 0;
        * ``chunk_size_list``: taken from the first example.
    """
    fields = ('input_ids', 'token_type_ids', 'attention_mask')
    chunk_size_list = examples[0]['chunk_size_list']
    seg_scales = ['seg-%d' % chunk_size for chunk_size in chunk_size_list]

    # One torch.cat per field instead of growing a tensor example by example
    # (which re-copied the whole batch each time). Plain dicts also remove the
    # dependency on an un-imported `defaultdict`.
    return {
        'labels': torch.FloatTensor([example['labels'] for example in examples]),
        'doc': {
            field: torch.cat([example['doc'][field] for example in examples])
            for field in fields
        },
        'seg': {
            scale: {
                field: torch.cat([example['seg'][scale][field] for example in examples], dim=0)
                for field in fields
            }
            for scale in seg_scales
        },
        'chunk_size_list': chunk_size_list,
    }
            

In [16]:
from transformers import BertTokenizer
import math
from tqdm.autonotebook import tqdm
# Uncased BERT word-piece tokenizer; the dataset pairs every essay with its normalized score.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
dataset = AESDataset(data=(asap_data['essay'], asap_data['score']), tokenizer=tokenizer)


  0%|                                                                                                                                                                             | 0/12976 [00:00<?, ?it/s][A
  1%|█▍                                                                                                                                                               | 112/12976 [00:00<00:11, 1114.09it/s][A
  2%|██▊                                                                                                                                                              | 224/12976 [00:00<00:11, 1112.73it/s][A
  3%|████▏                                                                                                                                                            | 336/12976 [00:00<00:11, 1056.65it/s][A
  3%|█████▍                                                                                                                                                           |

In [21]:
# Segment-scale ids of the first essay at chunk size 90: (1, num_segments, 90).
dataset[0]['seg']['seg-90']['input_ids'].size()

torch.Size([1, 17, 90])

# Build Base Model

In [25]:
def init_weights(m, bias_value=7.0):
    """Initialise an ``nn.Linear`` in place; meant for ``module.apply(init_weights)``.

    Weights get Xavier-uniform init and the bias, if the layer has one, is filled
    with ``bias_value``. Other module types are left untouched.

    NOTE(review): the default bias of 7 is kept from the original code; its rationale
    is not visible here — confirm it against the target score scale.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        # Layers built with bias=False have m.bias None; filling it used to crash.
        if m.bias is not None:
            nn.init.constant_(m.bias, bias_value)

In [41]:
class DocumentTokenCombineEncoder(nn.Module):
    """Document-scale scorer: BERT over the essay, then a single regression output.

    The max-pooled token states and BERT's pooled output are concatenated
    (2 * hidden_size) and passed through a dropout + linear head.
    """

    def __init__(self, bert_model_config):
        # nn.Module.__init__ takes no positional arguments; passing the config raised TypeError.
        super(DocumentTokenCombineEncoder, self).__init__()
        self.bert_model_config = bert_model_config
        # NOTE(review): BertModel(config) builds randomly initialised weights, and
        # MultiScaleModel(freeze=True) keeps them frozen. Load pretrained weights
        # (e.g. BertModel.from_pretrained) if that is the intent.
        self.bert = BertModel(self.bert_model_config)
        self.classifier = nn.Sequential(
            nn.Dropout(p=self.bert_model_config.hidden_dropout_prob),
            nn.Linear(self.bert_model_config.hidden_size * 2, 1)
        )
        self.classifier.apply(init_weights)

    def forward(self, inputs):
        """Score a batch of documents.

        Args:
            inputs: dict of BERT keyword inputs (``input_ids``, ``token_type_ids``,
                ``attention_mask``); presumably each (batch, seq_len) — TODO confirm.

        Returns:
            Tensor of shape (batch, 1).
        """
        outputs = self.bert(**inputs)
        # Index instead of tuple-unpacking: unpacking a transformers ModelOutput yields its keys.
        last_hidden_state, pooler_output = outputs[0], outputs[1]
        attention_mask = inputs.get('attention_mask')
        if attention_mask is not None:
            # Keep padding positions out of the max-pool.
            padding = (attention_mask == 0).unsqueeze(-1)
            last_hidden_state = last_hidden_state.masked_fill(
                padding, torch.finfo(last_hidden_state.dtype).min)
        token_embeds = torch.max(last_hidden_state, dim=1).values
        document_token_embeds = torch.cat([token_embeds, pooler_output], 1)
        return self.classifier(document_token_embeds)
    

In [28]:
class SegmentLSTMAttentionEncoder(nn.Module):
    """Segment-scale scorer.

    Pipeline: BERT pooled output per segment, an LSTM over each essay's segments,
    additive attention pooling over segments, and a linear regression head.
    """

    def __init__(self, bert_model_config):
        # nn.Module.__init__ takes no positional arguments; passing the config raised TypeError.
        super(SegmentLSTMAttentionEncoder, self).__init__()
        hidden_size = bert_model_config.hidden_size
        self.bert_model_config = bert_model_config
        # NOTE(review): BertModel(config) builds randomly initialised weights; load
        # pretrained weights (e.g. BertModel.from_pretrained) if that is the intent.
        self.bert = BertModel(self.bert_model_config)
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(p=bert_model_config.hidden_dropout_prob)
        # Attention parameters: score = tanh(h @ W + b) @ q, softmaxed over segments.
        self.w_weight = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
        self.b_weight = nn.Parameter(torch.Tensor(1, hidden_size))
        self.q_weight = nn.Parameter(torch.Tensor(hidden_size, 1))
        self.classifier = nn.Sequential(
            nn.Dropout(p=bert_model_config.hidden_dropout_prob),
            nn.Linear(hidden_size, 1)
        )
        nn.init.uniform_(self.w_weight, -0.1, 0.1)
        nn.init.uniform_(self.b_weight, -0.1, 0.1)
        nn.init.uniform_(self.q_weight, -0.1, 0.1)
        self.classifier.apply(init_weights)

    def forward(self, inputs):
        """Score a batch of segmented essays.

        Args:
            inputs: dict with ``input_ids``, ``token_type_ids`` and ``attention_mask``,
                indexed first by essay. ``inputs[k][i]`` is fed to BERT as a batch of
                segments, presumably (num_segments, chunk_size) — TODO confirm.

        Returns:
            Tensor of shape (batch, 1).
        """
        input_ids, token_type_ids, attention_mask = inputs['input_ids'], inputs['token_type_ids'], inputs['attention_mask']
        # One BERT call per essay; its segments form the BERT batch.
        segment_embeds = []
        for batch_idx in range(input_ids.shape[0]):
            bert_outputs = self.bert(input_ids=input_ids[batch_idx],
                                     token_type_ids=token_type_ids[batch_idx],
                                     attention_mask=attention_mask[batch_idx])
            # Index instead of unpacking: unpacking a transformers ModelOutput yields its keys.
            segment_embeds.append(bert_outputs[1])
        # (batch, num_segments, hidden)
        pooler_output = self.dropout(torch.stack(segment_embeds, dim=0))
        self.lstm.flatten_parameters()
        # The LSTM is batch_first, so (batch, segments, hidden) goes in directly. The old
        # permute made the LSTM recur across the essays of a batch instead of their segments.
        hidden_state, _ = self.lstm(pooler_output)
        alpha_hat = torch.tanh(torch.matmul(hidden_state, self.w_weight) + self.b_weight)
        alpha = F.softmax(torch.matmul(alpha_hat, self.q_weight), dim=1)  # (batch, segments, 1)
        output = torch.sum(hidden_state * alpha, dim=1)  # (batch, hidden)
        return self.classifier(output)

# Build Learning-Rate Schedule (Linear Warmup + Cosine Decay)

In [37]:
class CosineWarmupScheduler(optim.lr_scheduler._LRScheduler):
    """Linear warmup followed by cosine decay, advanced once per ``step()`` call.

    The learning rate of each param group is ``base_lr * factor(step)``, where
    ``factor(step) = 0.5 * (1 + cos(pi * step / max_iters))``. While
    ``step <= warmup``, the factor is additionally multiplied by ``step / warmup``.
    Past ``max_iters`` the cosine rises again, so ``max_iters`` should cover all
    training steps.

    Args:
        optimizer: wrapped optimizer.
        warmup: number of warmup steps; 0 disables warmup.
        max_iters: step count at which the cosine factor reaches zero.
    """

    def __init__(self, optimizer, warmup, max_iters):
        self.warmup = warmup
        self.max_num_iters = max_iters
        super().__init__(optimizer)

    def get_lr(self):
        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
        return [base_lr * lr_factor for base_lr in self.base_lrs]

    def get_lr_factor(self, epoch):
        """Multiplier applied to every base learning rate at step ``epoch``."""
        lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters))
        # Guard warmup == 0, which divided by zero on the scheduler's initial step.
        if self.warmup > 0 and epoch <= self.warmup:
            lr_factor *= epoch * 1.0 / self.warmup
        return lr_factor

# Define Metric Functions (Quadratic Weighted Kappa)

In [43]:
def confusion_matrix(rater_a, rater_b, min_rating=None, max_rating=None):
    """Count co-occurring integer ratings of two raters.

    Args:
        rater_a, rater_b: equally long sequences (lists or arrays) of integer ratings.
        min_rating, max_rating: inclusive rating range; inferred from the data when None.

    Returns:
        ``num_ratings x num_ratings`` nested list; ``[i][j]`` counts items rated
        ``min_rating + i`` by rater a and ``min_rating + j`` by rater b.
    """
    assert(len(rater_a) == len(rater_b))
    # Take the min/max of each rater separately: `rater_a + rater_b` concatenates
    # lists but adds numpy arrays element-wise, which produced a wrong range.
    if min_rating is None:
        min_rating = min(min(rater_a), min(rater_b))
    if max_rating is None:
        max_rating = max(max(rater_a), max(rater_b))

    num_ratings = int(max_rating - min_rating + 1)
    conf_mat = [[0 for _ in range(num_ratings)] for _ in range(num_ratings)]
    for a, b in zip(rater_a, rater_b):
        conf_mat[a - min_rating][b - min_rating] += 1
    return conf_mat


def histogram(ratings, min_rating=None, max_rating=None):
    """Count how often each integer rating in ``[min_rating, max_rating]`` occurs.

    The range is inferred from ``ratings`` when not given. Entry ``k`` of the returned
    list counts rating ``min_rating + k``.
    """
    if min_rating is None:
        min_rating = min(ratings)
    if max_rating is None:
        max_rating = max(ratings)
    num_ratings = int(max_rating - min_rating + 1)
    hist_ratings = [0 for _ in range(num_ratings)]
    for r in ratings:
        hist_ratings[r - min_rating] += 1
    return hist_ratings


def _to_scaled_int_ratings(ratings):
    """Flatten scores into a 1-D int array of ``score * 10`` (truncated); accepts torch tensors."""
    if hasattr(ratings, 'detach'):  # torch.Tensor: drop the autograd graph, move to host memory
        ratings = ratings.detach().cpu().numpy()
    return (np.array(ratings) * 10).astype(int).reshape(-1, )


def qwk(rater_a, rater_b, min_rating=None, max_rating=None):
    """Quadratic weighted kappa between two sets of scores.

    Scores are multiplied by 10 and truncated to int before comparison, so 0-10 scores
    are compared at 0.1 resolution. Torch tensors (including ones that require grad or
    live on a GPU) are accepted. Inputs of any shape are flattened.

    Args:
        rater_a, rater_b: equally long array-likes or tensors of scores.
        min_rating, max_rating: inclusive range of the *scaled* (x10, int) ratings;
            inferred from the data when None. Ratings outside it are clipped to it.
            NOTE(review): a range given on the unscaled score scale (e.g. (0, 10) for
            0-10 scores) clips almost everything — confirm what callers pass.

    Returns:
        float kappa; 1.0 means perfect agreement.
    """
    rater_a = _to_scaled_int_ratings(rater_a)
    rater_b = _to_scaled_int_ratings(rater_b)
    assert(len(rater_a) == len(rater_b))
    if min_rating is None:
        min_rating = min(min(rater_a), min(rater_b))
    if max_rating is None:
        max_rating = max(max(rater_a), max(rater_b))

    min_rating = int(math.ceil(min_rating))
    max_rating = int(math.floor(max_rating))
    # Regression outputs can leave the range; unclipped they index past the matrix
    # (IndexError) or, when negative, silently wrap around to its end.
    rater_a = np.clip(rater_a, min_rating, max_rating)
    rater_b = np.clip(rater_b, min_rating, max_rating)
    conf_mat = confusion_matrix(rater_a, rater_b, min_rating, max_rating)

    num_ratings = len(conf_mat)
    num_scored_items = float(len(rater_a))

    hist_rater_a = histogram(rater_a, min_rating, max_rating)
    hist_rater_b = histogram(rater_b, min_rating, max_rating)

    # Quadratic disagreement weights; the epsilon keeps a single-category range from
    # dividing by zero (every weight is then 0 anyway).
    weight_norm = pow(max(num_ratings - 1, 1e-7), 2.0)
    numerator = 0.0
    denominator = 0.0
    for i in range(num_ratings):
        for j in range(num_ratings):
            expected_count = (hist_rater_a[i] * hist_rater_b[j] / num_scored_items)
            d = pow(i - j, 2.0) / weight_norm
            numerator += d * conf_mat[i][j] / num_scored_items
            denominator += d * expected_count / num_scored_items

    if denominator <= 0.0000001:
        denominator = 0.0000001
    return 1.0 - numerator / denominator

# Build PyTorch Lightning Model

In [34]:
class MultiScaleModel(pl.LightningModule):
    """Multi-scale BERT essay scorer.

    Adds the document-scale prediction to one segment-scale prediction per chunk
    size. A single segment encoder is shared by all chunk sizes.

    Args:
        bert_model_config: BERT config shared by both encoders.
        score_range: (min, max) passed to ``qwk`` as its rating range.
        freeze: if True, the BERT weights of both encoders are not trained.
        lr: Adam learning rate (required by ``configure_optimizers``).
        warmup: warmup steps of the cosine schedule (required by ``configure_optimizers``).
        max_iters: total steps of the cosine schedule (required by ``configure_optimizers``).
    """

    def __init__(self, bert_model_config, score_range, freeze=True, lr=None, warmup=None, max_iters=None):
        super().__init__()
        # Records every __init__ argument in self.hparams; lr/warmup/max_iters used to be
        # read there without ever being saved.
        self.save_hyperparameters()
        self._create_model()

    def _create_model(self):
        """Build both encoders, optionally freeze their BERT weights, and set the MSE loss."""
        self.doc_scale_encoder = DocumentTokenCombineEncoder(self.hparams.bert_model_config)
        self.seg_scale_encoder = SegmentLSTMAttentionEncoder(self.hparams.bert_model_config)

        if self.hparams.freeze:
            for param in self.doc_scale_encoder.bert.parameters():
                param.requires_grad = False

            for param in self.seg_scale_encoder.bert.parameters():
                param.requires_grad = False

        self.loss_fn = nn.MSELoss()

    def forward(self, doc=None, seg=None, chunk_size_list=None, **kwargs):
        """Sum of the document-scale and per-chunk-size segment-scale predictions.

        Args:
            doc: BERT inputs for the document encoder (the ``'doc'`` entry of a batch).
            seg: dict mapping ``'seg-<chunk_size>'`` to BERT inputs for the segment encoder.
            chunk_size_list: chunk sizes whose segment predictions are added.

        Returns:
            Prediction tensor, presumably (batch, 1).
        """
        predictions = self.doc_scale_encoder(inputs=doc)
        for chunk_size in chunk_size_list or ():
            seg_representations = self.seg_scale_encoder(inputs=seg['seg-%d' % chunk_size])
            predictions = torch.add(predictions, seg_representations)

        return predictions

    def configure_optimizers(self):
        """Adam plus a per-step CosineWarmupScheduler (stepped in ``optimizer_step``)."""
        missing = [name for name in ('lr', 'warmup', 'max_iters') if self.hparams.get(name) is None]
        if missing:
            raise ValueError('MultiScaleModel needs %s to configure the optimizer' % ', '.join(missing))
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr)

        self.lr_scheduler = CosineWarmupScheduler(optimizer=optimizer, warmup=self.hparams.warmup,
                                                  max_iters=self.hparams.max_iters)

        return optimizer

    def optimizer_step(self, *args, **kwargs):
        super().optimizer_step(*args, **kwargs)
        # The schedule is defined per optimizer step, not per epoch.
        self.lr_scheduler.step()

    def _calculate_loss(self, batch, mode='train'):
        """Compute the MSE loss and QWK for one batch and log them as ``<mode>_loss``/``<mode>_qwk``."""
        # Shallow copy so the caller's batch dict is not mutated by the pop.
        inputs = dict(batch)
        labels = inputs.pop('labels')
        logits = self(**inputs)  # previously the notebook-global `model`, not this module
        loss = self.loss_fn(logits.squeeze(1), labels)
        # Distinct local name: assigning to `qwk` made the function name local (UnboundLocalError).
        # qwk works on numpy, so detach and move to host first.
        qwk_score = qwk(logits.detach().cpu().numpy(), labels.detach().cpu().numpy(),
                        self.hparams.score_range[0], self.hparams.score_range[1])
        self.log("%s_loss" % mode, loss)
        self.log("%s_qwk" % mode, qwk_score)
        return loss, qwk_score

    def training_step(self, batch, batch_idx):
        loss, _ = self._calculate_loss(batch, mode='train')
        # Lightning expects a loss tensor (or a dict with 'loss'); a tuple is rejected.
        return loss

    def validation_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode='val')

    def test_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode='test')
        

In [None]:
def convert_data(df):
    """Return ``(essays, scores)`` Series from a frame with ``essay`` and ``score`` columns.

    The index is reset to 0..n-1 so positional and label access agree for
    shuffled/filtered frames.
    """
    df = df.reset_index(drop=True)
    return df['essay'], df['score']

def make_data(data_path, test_size=0.4, random_state=None):
    """Load a CSV and split it into train / test / validation ``(essays, scores)`` pairs.

    ``test_size`` of the rows are held out and split evenly into test and validation,
    so the default gives roughly 60/20/20.

    Args:
        data_path: path of a CSV with ``essay`` and ``score`` columns.
        test_size: fraction of rows held out for test + validation.
        random_state: seed for the shuffles (None uses NumPy's global RNG).

    Returns:
        ``(train_data, test_data, valid_data)``, each an ``(essays, scores)`` tuple.
    """
    data = pd.read_csv(data_path)  # previously the literal string 'data_path'
    # Pandas-only split: `train_test_split` was misspelled and never imported.
    held_out = data.sample(frac=test_size, random_state=random_state)
    train_data = data.drop(held_out.index)
    test_data = held_out.sample(frac=0.5, random_state=random_state)
    valid_data = held_out.drop(test_data.index)
    return convert_data(train_data), convert_data(test_data), convert_data(valid_data)
    
def train_ood():
    