In [1]:
from data_loaders.AUS_dataset import AUSDataset, AUSPytorchDataset

from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn as nn
from project_settings import EOS_TOK, EOC_TOK
import pdb
from tqdm import tqdm, trange
import numpy as np
from models.Model import MeanModel, TruncatModel, NNModel
from project_settings import ExpConfig, DatasetConfig
from utils import chunkify, encode_chunks

## Metric

In [2]:



def micro_contrast(left_left, left_right, right_right, right_left):
    count = (left_left.numel() + left_right.numel() + right_right.numel() + right_left.numel()) / 2
    sum = torch.sum(left_left) + torch.sum(right_right) - torch.sum(left_right) - torch.sum(right_left)
    return sum / count


# def micro_contrast(*arg):
#     count=0
#     sum=0
#     for tensor in arg:
#         count+=tensor.numel()
#         sum+=torch.sum(tensor)
#     return sum/count

def macro_contrast(left_left, left_right, right_right, right_left):
    return (torch.mean(left_left) - torch.mean(left_right) + torch.mean(right_right) - torch.mean(right_left)) / 2



## misc

In [3]:
def transform_chunk_to_dict(chunk):
    n,dim=chunk.size()
    chunk_dict={}
    chunk_dict["input_ids"]=chunk[0]
    chunk_dict["token_type_ids"]=chunk[1]
    chunk_dict["attention_mask"]=chunk[2]
    return chunk_dict

## Benchmark tranformer

Here we test the transformer via a **retrieval task**. We want to pair case **description** with the right **catchphrases** for each case in our legal dataset.

As seen preivously in the **dataset analysis**, the case description are in general very long, average length is **34k** chars,thus around **6k tokens**. However, transformer relying to squared attention only takes 512 tokens. So in this very naive baseline benchmark, we just truncate the sentence at **512th token** 

In [2]:
# cuda
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()

ds = AUSDataset()
exp_config = ExpConfig("MeanModel")
train_dataloader = ds.get_data_loader(split='train', batch_size=2, shuffle=True)
val_dataloader = ds.get_data_loader(split='val', batch_size=2, shuffle=True)
test_dataloader = ds.get_data_loader(split='test', batch_size=2, shuffle=True)

tokenizer = AutoTokenizer.from_pretrained(exp_config.uri)
encoder = AutoModel.from_pretrained(exp_config.uri) 
test_micro_contrast = 0  # running loss
test_macro_contrast = 0
nb_test_steps = 0


for step, batch in tqdm(enumerate(test_dataloader)):
    # Add batch to GPU
    "batch = tuple(t.to(device) for t in batch)"
    # Unpack the inputs from our dataloader
    sentences, catchphrases = batch  # len(sentences)=2, len(catchphrases)=2
    # Clear out the gradients (by default they accumulate)

    sentences_a, catchphrase_a = sentences[0], catchphrases[0]
    sentences_b, catchphrase_b = sentences[1], catchphrases[1]
    batch_catchphrase_a = catchphrase_a.split(EOC_TOK)
    batch_catchphrase_b = catchphrase_b.split(EOC_TOK)

    encoded_batch_catchphrase_a = tokenizer(batch_catchphrase_a, truncation=True, return_tensors="pt",
                                            padding='max_length', max_length=512)
    encoded_batch_catchphrase_b = tokenizer(batch_catchphrase_b, truncation=True, return_tensors="pt",
                                            padding='max_length', max_length=512)

    encoded_sentence_a = tokenizer(sentences_a, truncation=True, return_tensors="pt", padding='max_length',
                                   max_length=512)
    encoded_sentence_b = tokenizer(sentences_b, truncation=True, return_tensors="pt", padding='max_length',
                                   max_length=512)
    print("sentences_a length:", len(sentences_a))
    _, batch_catchphrase_embedding_a = encoder(**encoded_batch_catchphrase_a)  # [7, 768]
    _, batch_catchphrase_embedding_b = encoder(**encoded_batch_catchphrase_b)  # [13,768]

    _, sentence_embedding_a = encoder(**encoded_sentence_a)  # [1, 768]
    _, sentence_embedding_b = encoder(**encoded_sentence_b)  # [1, 768]

    left_left = torch.cdist(sentence_embedding_a, batch_catchphrase_embedding_a, p=2.0)  # [1, 768]*[7, 768]=[1, 7]
    left_right = torch.cdist(sentence_embedding_a, batch_catchphrase_embedding_b,
                             p=2.0)  # [1, 768]*[13, 768]=[1, 13]

    right_right = torch.cdist(sentence_embedding_b, batch_catchphrase_embedding_b,
                              p=2.0)  # [1, 768]*[13, 768]=[1, 13]
    right_left = torch.cdist(sentence_embedding_b, batch_catchphrase_embedding_a, p=2.0)  # [1, 768]*[7, 768]=[1, 7]

    nb_test_steps += 1
    test_macro_contrast += macro_contrast(left_left, left_right, right_right, right_left)
    test_micro_contrast += micro_contrast(left_left, left_right, right_right, right_left)

    print("Test micro contrast: {}".format(test_micro_contrast / nb_test_steps))
    print("Test macro contrast: {}".format(test_macro_contrast / nb_test_steps))

0it [00:00, ?it/s]

sentences_a length: 11888





NameError: name 'macro_contrast' is not defined

In [3]:


# def train_contrast_retrieval(data_config, exp_config):


# cuda
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()

ds = AUSDataset()
train_dataloader = ds.get_data_loader(split='train', batch_size=2, shuffle=True)
val_dataloader = ds.get_data_loader(split='val', batch_size=2, shuffle=True)
test_dataloader = ds.get_data_loader(split='test', batch_size=2, shuffle=True)

tokenizer = AutoTokenizer.from_pretrained(exp_config.uri)
encoder = AutoModel.from_pretrained(exp_config.uri)

model = NNModel(exp_config)

optimizer = torch.optim.Adam(encoder.parameters(), lr=0.001)
print(encoder.parameters())

for name, param in encoder.named_parameters():
    print(name, param.size())
# Store our loss and accuracy for plotting
train_loss_set = []

# Number of training epochs (authors recommend between 2 and 4)
epochs = exp_config.epochs

# trange is a tqdm wrapper around the normal python range
for epoch__ in trange(epochs, desc="Epoch"):

    print("start training")

    # Set our model to training mode (as opposed to evaluation mode)
    encoder.train()

    # Tracking variables
    tr_loss = 0  # running loss
    nb_tr_steps = 0

    # Train the data for one epoch
    for step, batch in tqdm(enumerate(train_dataloader)):
        optimizer.zero_grad()
        # Unpack the inputs from our dataloader
        sentences, catchphrases = batch  # len(sentences)=2, len(catchphrases)=2
        # Clear out the gradients (by default they accumulate)

        sentences_a, catchphrase_a = sentences[0], catchphrases[0]
        sentences_b, catchphrase_b = sentences[1], catchphrases[1]

        batch_catchphrase_a = catchphrase_a.split(EOC_TOK)
        batch_catchphrase_b = catchphrase_b.split(EOC_TOK)

        encoded_batch_catchphrase_a = tokenizer(batch_catchphrase_a, truncation=True, return_tensors="pt",
                                                padding='max_length', max_length=128)
        encoded_batch_catchphrase_b = tokenizer(batch_catchphrase_b, truncation=True, return_tensors="pt",
                                                padding='max_length', max_length=128)

        sentence_indices_a = tokenizer(sentences_a, truncation=True, return_tensors="pt", padding='max_length',
                                       max_length=512*12)
        sentence_indices_b = tokenizer(sentences_b, truncation=True, return_tensors="pt", padding='max_length',
                                       max_length=512*12)

        _, batch_catchphrase_embedding_a = encoder(**encoded_batch_catchphrase_a)  # [7, 768]
        _, batch_catchphrase_embedding_b = encoder(**encoded_batch_catchphrase_b)  # [13,768]

        
        chunk_indices_a=chunkify(sentence_indices_a)
        
        chunk_indices_b=chunkify(sentence_indices_b)
        
#         chunk_embeddings_a=[]
#         chunk_embeddings_b=[]
#         for i, chunk_indice in enumerate(chunk_indices_a):
#             _, chunk_embedding= encoder(**chunk_indices_a[i])  # [1, 768]
#             chunk_embeddings_a.append(torch.squeeze(chunk_embedding))
        
#         chunk_embeddings_a=torch.stack(chunk_embeddings_a,dim=0)

        chunk_embeddings_a=encode_chunks(chunk_indices_a,encoder)
        chunk_embeddings_b=encode_chunks(chunk_indices_b,encoder)
    
        
        #################### Aggregation ######################
        sentence_embedding_a=torch.mean(chunk_embeddings_a,dim=0)
        sentence_embedding_b=torch.mean(chunk_embeddings_b,dim=0)
        

        triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)

        anchor_a = sentence_embedding_a.unsqueeze(0)
        anchor_b = sentence_embedding_b.unsqueeze(0)
        batch_train_loss_set=[]
        for catchphrase_embedding_a in tqdm(batch_catchphrase_embedding_a):
            for catchphrase_embedding_b in batch_catchphrase_embedding_b:
                positive = catchphrase_embedding_a.unsqueeze(0)
                negative = catchphrase_embedding_b.unsqueeze(0)
                loss = triplet_loss(anchor_a, positive, negative)
                batch_train_loss_set.append(loss.unsqueeze(0))
                loss = triplet_loss(anchor_b, negative, positive)
                batch_train_loss_set.append(loss.unsqueeze(0))
        batch_loss=torch.mean(torch.cat(batch_train_loss_set))
        batch_loss.backward()
        optimizer.step()
        # Update tracking variables
        tr_loss += batch_loss.item()
        nb_tr_steps += 1

        print("Train loss: {}".format(tr_loss/nb_tr_steps))





Epoch:   0%|          | 0/10 [00:00<?, ?it/s]

<generator object Module.parameters at 0x12659a850>
start training



0it [00:00, ?it/s][A

100%|██████████| 7/7 [00:00<00:00, 565.79it/s]

1it [00:02,  2.26s/it][A

Train loss: 1.0587562322616577




100%|██████████| 4/4 [00:00<00:00, 674.11it/s]

2it [00:04,  2.19s/it][A

Train loss: 1.423868715763092


2it [00:04,  2.25s/it]
Epoch:   0%|          | 0/10 [00:04<?, ?it/s]


KeyboardInterrupt: 

In [2]:
if __name__ == '__main__':
    ds = AUSDataset()

    # ds.save_processed_splits()

    test_dataloader = ds.get_data_loader(split='test', batch_size=2, shuffle=True)
    # print(test_dl.batch_size)
    # for i in test_dl:
    #     print(len(i[0]),len(i[1]))
    data_config = DatasetConfig("AUS")
    exp_config = ExpConfig("MeanModel")
    train_contrast_retrieval(data_config, exp_config)

NameError: name 'train_contrast_retrieval' is not defined