## Importing Required Components

In [None]:
import itertools
import math
import os
import pickle

import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset, random_split
from transformers import AdamW, AlbertForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup

## Preparing the Model and Tokenizer

In [None]:
# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tokenizer and sequence-classification model come from the same pretrained
# ALBERT checkpoint; the model is moved to `device` right after loading.
tokenizer = AutoTokenizer.from_pretrained("albert-xlarge-v2")
model = AlbertForSequenceClassification.from_pretrained("albert-xlarge-v2").to(device)

# Number of sentences per example that are paired up and later reordered.
num_of_sentences_to_order = 5

## Converting Raw Data to Torch Dataset

In [None]:
def read_pickle_data(addr, mode="rb"):
    """Load and return the object pickled in the file at ``addr``.

    ``mode`` is forwarded to ``open``. NOTE: unpickling can execute arbitrary
    code, so only point this at files from a trusted source.
    """
    with open(addr, mode) as handle:
        return pickle.load(handle)

def write_pickle_data(addr, data, mode="wb"):
    """Pickle ``data`` into the file at ``addr``.

    ``mode`` is forwarded to ``open``; the default ``"wb"`` overwrites any
    existing file. Returns ``None``.
    """
    with open(addr, mode) as out_file:
        pickle.dump(data, out_file)

def create_pairs_from_orig_data(orig_data, num_sentences=None, positive_copies=2):
    """Turn raw ordering examples into labelled sentence-pair rows.

    Each value of ``orig_data`` is read as ``example[0]`` (the sentences) and
    ``example[1]`` (one sort key per sentence). Sentences are put in order by
    sorting on those keys. For every ordered pair ``(i, j)``, ``i != j``, of the
    first ``num_sentences`` ordered sentences, a row
    ``[sentence_i, sentence_j, label]`` is emitted with both sentences
    lower-cased:

    * label ``0`` when sentence ``j`` directly follows sentence ``i``
      (``j == i + 1``). This row is repeated ``positive_copies`` times, because
      there are only ``n - 1`` such pairs versus ``(n - 1) ** 2`` others per
      example, and the copies make the dataset more balanced;
    * label ``1`` otherwise.

    Args:
        orig_data: mapping of example id -> example, as described above.
        num_sentences: how many ordered sentences per example to pair up.
            Defaults to the notebook-level ``num_of_sentences_to_order``.
        positive_copies: how many copies of each label-0 row to emit.

    Returns:
        A list of ``[str, str, int]`` rows.
    """
    n = num_of_sentences_to_order if num_sentences is None else num_sentences

    pairs = []
    for example in orig_data.values():
        # Sort (key, sentence) tuples by key and keep the lower-cased sentences.
        ordered = [sentence.lower() for _, sentence in sorted(zip(example[1], example[0]))]
        # All ordered selections of two distinct sentence positions.
        for i, j in itertools.permutations(range(n), 2):
            if j == i + 1:
                pairs.extend([ordered[i], ordered[j], 0] for _ in range(positive_copies))
            else:
                pairs.append([ordered[i], ordered[j], 1])

    return pairs


def create_dataset_from_pairs(pairs, max_length=128):
    """Tokenize sentence pairs into a ``TensorDataset`` for sequence classification.

    Args:
        pairs: non-empty sequence of ``[sentence1, sentence2, label]`` rows,
            e.g. the output of ``create_pairs_from_orig_data``.
        max_length: every encoded pair is padded and truncated to exactly this
            many tokens.

    Returns:
        ``TensorDataset(input_ids, token_type_ids, attention_mask, labels)``.
        The first three tensors are ``(len(pairs), max_length)``, and ``labels``
        is a 1-D tensor.

    Raises:
        ValueError: if ``pairs`` is empty.
    """
    if not pairs:
        raise ValueError("create_dataset_from_pairs() needs at least one pair")

    first_sentences, second_sentences, labels = zip(*pairs)

    # truncation=True is required: with padding="max_length" alone, transformers
    # does NOT truncate, so a pair longer than max_length would yield a longer
    # row and break the fixed-width tensors.
    encoding = tokenizer(list(first_sentences), list(second_sentences),
                         add_special_tokens=True,
                         max_length=max_length,
                         padding="max_length",
                         truncation=True,
                         return_tensors="pt")

    return TensorDataset(encoding["input_ids"],
                         encoding["token_type_ids"],
                         encoding["attention_mask"],
                         torch.tensor(labels))

def get_train_val_datasets(ds, val_size=512, generator=None):
    """Randomly split ``ds`` into a training subset and a ``val_size``-item validation subset.

    Args:
        ds: any sized dataset accepted by ``torch.utils.data.random_split``.
        val_size: number of items placed in the validation subset.
        generator: optional ``torch.Generator``. Pass a seeded one for a
            reproducible split; by default torch's global RNG is used.

    Returns:
        ``(train_ds, val_ds)`` as ``Subset`` objects.

    Raises:
        ValueError: if ``val_size`` is negative or larger than ``len(ds)``.
            Without this check, a negative train length silently produces a
            wrong split.
    """
    if not 0 <= val_size <= len(ds):
        raise ValueError(f"val_size must be between 0 and {len(ds)} (the dataset size), got {val_size}")

    split_kwargs = {} if generator is None else {"generator": generator}
    train_ds, val_ds = random_split(ds, [len(ds) - val_size, val_size], **split_kwargs)
    return train_ds, val_ds

def get_datasets(train_path="data/train.pickle", cache_path=None):
    """Build the train/validation split of tokenized sentence-pair datasets.

    Args:
        train_path: pickle file holding the raw training examples, in the
            format consumed by ``create_pairs_from_orig_data``.
        cache_path: optional pickle file for the tokenized ``TensorDataset``.
            If the file exists, the dataset is loaded from it and tokenization
            is skipped. If it does not exist, the dataset is built and written
            there. ``None`` (default) disables caching.

    Returns:
        ``(train_ds, val_ds)`` as produced by ``get_train_val_datasets``.
    """
    if cache_path is not None and os.path.exists(cache_path):
        ds = read_pickle_data(cache_path)
    else:
        pairs = create_pairs_from_orig_data(read_pickle_data(train_path))
        ds = create_dataset_from_pairs(pairs)
        if cache_path is not None:
            write_pickle_data(cache_path, ds)

    train_ds, val_ds = get_train_val_datasets(ds)
    return train_ds, val_ds

## Getting Train and Validation Datasets

In [None]:
# Tokenized sentence-pair datasets: a random validation split (512 items by
# default in get_train_val_datasets) and the remaining pairs for training.
train_ds, val_ds = get_datasets()

## Configuring Train and Validation Dataloaders

In [None]:
batch_size = 8
# Training pairs drawn by the random sampler per pass over train_dataloader.
num_samples = 17000

# Training: random sampling capped at `num_samples` items per epoch.
train_dataloader = DataLoader(
    train_ds,
    batch_size=batch_size,
    sampler=RandomSampler(train_ds, num_samples=num_samples),
)
# Validation: fixed, in-order iteration.
val_dataloader = DataLoader(
    val_ds,
    batch_size=batch_size,
    sampler=SequentialSampler(val_ds),
)

## Configuring Optimizer and Scheduler

In [None]:
epochs = 15
learning_rate = 8e-6
# Gradient-accumulation factor. The training loop calls optimizer.step() and
# scheduler.step() once every `accum_num` batches, plus once for a trailing
# partial group. The schedule length must therefore count optimizer steps,
# not raw batches; otherwise the linear decay stops about a quarter of the way
# down. Keep in sync with `accum_num` in the training cell.
accum_num = 4

# torch.optim.AdamW replaces the deprecated transformers.AdamW.
# weight_decay=0.0 matches the transformers default, so optimization is unchanged.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, eps=1e-8, weight_decay=0.0)
optimizer_steps_per_epoch = math.ceil(len(train_dataloader) / accum_num)
scheduler = get_linear_schedule_with_warmup(optimizer=optimizer,
                                            num_warmup_steps=0,
                                            num_training_steps=optimizer_steps_per_epoch * epochs)
best_val_loss = math.inf

## Checking Dataloader Sizes (Batches per Epoch)

In [None]:
# Batches per epoch: training first, then validation.
for loader in (train_dataloader, val_dataloader):
    print(len(loader))

## Training and Evaluating the Model

In [None]:
train_losses = []
val_losses = []
accum_num = 4  # Number of batches (steps) whose gradients are accumulated per optimizer step


def _forward(batch):
    """Move a (input_ids, token_type_ids, attention_mask, labels) batch to `device` and run the model on it."""
    input_ids, token_type_ids, attention_mask, labels = (tensor.to(device) for tensor in batch)
    return model(input_ids,
                 token_type_ids=token_type_ids,
                 attention_mask=attention_mask,
                 labels=labels,
                 return_dict=True)


num_train_batches = len(train_dataloader)

for epoch in range(epochs):
    print(f"Epoch: {epoch + 1}")
    # Train Phase
    train_loss = 0
    model.train()
    model.zero_grad()

    for step, batch in enumerate(train_dataloader):
        batch_loss = _forward(batch).loss
        train_loss += batch_loss.item()  # track the unscaled per-batch loss
        # Scale so the accumulated gradient is the mean over `accum_num`
        # batches instead of their sum.
        (batch_loss / accum_num).backward()

        # Update once per `accum_num` batches, and flush a trailing partial group.
        if (step + 1) % accum_num == 0 or step + 1 == num_train_batches:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            model.zero_grad()

        if step % 1000 == 0:
            print(f"Step {step} Passed.")

    avg_train_loss = train_loss / num_train_batches
    train_losses.append(avg_train_loss)

    # Evaluation Phase
    val_loss = 0
    model.eval()
    with torch.no_grad():
        for batch in val_dataloader:
            val_loss += _forward(batch).loss.item()

    avg_val_loss = val_loss / len(val_dataloader)
    val_losses.append(avg_val_loss)

    print(f"Epoch Validation Loss: {avg_val_loss}")
    print(f"Epoch Train Loss: {avg_train_loss}")

    if avg_val_loss < best_val_loss:  # Checkpoint the model with the lowest validation loss so far
        best_val_loss = avg_val_loss
        model.save_pretrained("models/best_model")

# Plotting Train and Validation Loss
plt.plot(train_losses)
plt.plot(val_losses)
plt.title('Model Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Val'], loc='upper left')
plt.show()

## Reading Test Data

In [None]:
# Raw test examples keyed by id. The prediction code reads element [0] of each
# value as the list of sentences to reorder.
test_data_orig = read_pickle_data("data/test.pickle")

## Reordering Test Data and Generating the Output

In [None]:
# Predict with the checkpoint that had the lowest validation loss, as saved by
# the training cell. The in-memory `model` holds last-epoch weights, so fall
# back to it only if no checkpoint exists.
best_model_dir = "models/best_model"
if os.path.isdir(best_model_dir):
    inference_model = AlbertForSequenceClassification.from_pretrained(best_model_dir).to(device)
else:
    inference_model = model
inference_model.eval()

# Every ordered pair (i, j), i != j; all are scored in one batched forward pass per example.
pair_indices = list(itertools.permutations(range(num_of_sentences_to_order), 2))

result_rows = []
for index in test_data_orig:
    # Lower-case the same way create_pairs_from_orig_data does for training pairs.
    sentences = [sentence.lower() for sentence in test_data_orig[index][0]]

    encoding = tokenizer([sentences[i] for i, _ in pair_indices],
                         [sentences[j] for _, j in pair_indices],
                         add_special_tokens=True,
                         max_length=128,
                         padding="max_length",
                         truncation=True,
                         return_tensors="pt")
    with torch.no_grad():
        logits = inference_model(input_ids=encoding["input_ids"].to(device),
                                 token_type_ids=encoding["token_type_ids"].to(device),
                                 attention_mask=encoding["attention_mask"].to(device),
                                 return_dict=True).logits

    # Class 0 means "the second sentence directly follows the first". Use its
    # log-probability (not the raw class-0 logit, which ignores class 1), so
    # that summing over consecutive pairs gives a joint log-probability.
    follow_log_probs = torch.log_softmax(logits, dim=-1)[:, 0].cpu().tolist()

    # score_mat[i][j]: log-probability that sentence j directly follows sentence i.
    score_mat = [[0.0] * num_of_sentences_to_order for _ in range(num_of_sentences_to_order)]
    for (i, j), score in zip(pair_indices, follow_log_probs):
        score_mat[i][j] = score

    # Exhaustive search over all n! orders for the highest total score of
    # consecutive pairs. On ties, max() keeps the first maximal permutation.
    best_permutation = max(
        itertools.permutations(range(num_of_sentences_to_order)),
        key=lambda perm: sum(score_mat[a][b] for a, b in zip(perm, perm[1:])),
    )

    # For each original sentence k, output its position in the predicted order.
    result_rows.append([index] + [best_permutation.index(k) for k in range(num_of_sentences_to_order)])


## Writing Results to Disk

In [None]:
# One "index_k" column per sentence (k is 1-based), holding that sentence's
# predicted position. The original used range(len(int)), which raised TypeError.
columns = ["id"] + [f"index_{i + 1}" for i in range(num_of_sentences_to_order)]
pd.DataFrame(result_rows, columns=columns).to_csv("results.csv", index=False)