In [None]:
# Fetch the SNLI 1.0 archive (-nc: do not download again if the zip is already
# present), unpack it into the working directory, then delete the zip.
!wget -nc https://nlp.stanford.edu/projects/snli/snli_1.0.zip
!unzip snli_1.0.zip
!rm snli_1.0.zip

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Names of the SNLI splits, in the order they are read from disk.
keys = ["train", "test", "dev"]

In [None]:
import json

In [None]:
# Parse every split's JSON-lines file into a list of dicts, keyed by split name.
data = {k: [] for k in keys}
for k in keys:
    # `with` guarantees the handle is closed (the bare open().readlines() leaked
    # it), the explicit encoding avoids depending on the platform default, and
    # iterating the file object streams lines instead of materialising them all.
    with open('snli_1.0/snli_1.0_' + k + '.jsonl', encoding='utf-8') as split_file:
        for line in split_file:
            data[k].append(json.loads(line))

In [None]:
import spacy
# Load spaCy's "en_core_web_sm" pipeline.
# NOTE(review): `nlp` is not referenced by any other visible cell — confirm it
# is still needed before paying the load cost.
nlp = spacy.load("en_core_web_sm")
from tqdm import tqdm
import pandas as pd

def collect_data(docs, exp_split):
    """Turn raw SNLI records into a DataFrame of labelled sentence pairs.

    Records whose ``gold_label`` is not ``contradiction``, ``neutral`` or
    ``entailment`` are skipped.

    Parameters
    ----------
    docs : iterable of dict
        Parsed SNLI JSON lines. Each record must provide ``gold_label``,
        ``captionID``, ``sentence1`` and ``sentence2``.
    exp_split : str
        Value written to the ``exp_split`` column of every row.

    Returns
    -------
    pandas.DataFrame
        One row per kept record with the columns ``annotation_id``,
        ``doc_id``, ``document``, ``label_id``, ``label``, ``query`` and
        ``exp_split``. The columns are present even when no record is kept.

    Raises
    ------
    KeyError
        If a record lacks one of the required fields.
    """
    labels = {"contradiction": 0, "neutral": 1, "entailment": 2}
    columns = [
        "annotation_id", "doc_id", "document",
        "label_id", "label", "query", "exp_split",
    ]

    # Rows already emitted per caption. Counting per caption (rather than per
    # contiguous run) keeps annotation_id unique even when the same captionID
    # shows up again later in the file; for contiguous runs the ids are the
    # same as a simple running counter would give.
    emitted_per_caption = {}
    annotations = []

    for doc in tqdm(docs):
        gold_label = doc["gold_label"]
        if gold_label not in labels:
            continue

        caption_id = doc["captionID"]
        position = emitted_per_caption.get(caption_id, 0)
        emitted_per_caption[caption_id] = position + 1

        annotations.append({
            'annotation_id': caption_id + "_" + str(position),
            'doc_id': caption_id,
            'document': doc["sentence1"],
            # Note: `label_id` carries the label name and `label` its integer code.
            'label_id': gold_label,
            'label': labels[gold_label],
            'query': doc["sentence2"],
            'exp_split': exp_split,
        })

    # Explicit columns so an empty result still has the expected schema.
    return pd.DataFrame(annotations, columns=columns)

In [None]:
# Build one DataFrame per split (evaluated in the order train, dev, test);
# the split name doubles as the lookup key and the `exp_split` tag.
train, dev, test = (
    collect_data(data[split_name], split_name)
    for split_name in ("train", "dev", "test")
)

len(train), len(dev), len(test)

In [None]:
# Add a "lengths" column to each split: the number of whitespace-separated
# tokens in the "query" text.
for split_frame in (train, dev, test):
    split_frame["lengths"] = split_frame["query"].apply(lambda text: len(text.split()))

In [None]:
# Keep only rows whose query has more than one token, then discard the
# helper "lengths" column.
train, dev, test = (
    split_frame[split_frame["lengths"] > 1].drop(columns="lengths")
    for split_frame in (train, dev, test)
)

In [None]:
# Row counts per split after filtering, shown as (train, dev, test).
tuple(map(len, (train, dev, test)))

In [None]:
import os
# Make sure the output directory exists, then write each split to its own CSV
# without the DataFrame index.
os.makedirs("data", exist_ok=True)
for split_frame, csv_path in (
    (train, "data/train.csv"),
    (dev, "data/dev.csv"),
    (test, "data/test.csv"),
):
    split_frame.to_csv(csv_path, index=False)

In [None]:
# Remove the extracted directories now that the CSVs have been written.
# NOTE(review): __MACOSX/ is presumably a metadata folder shipped inside the
# zip — confirm it is always created by the unzip step.
!rm -r __MACOSX/
!rm -r snli_1.0/