In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [16]:
class DAN_conf:
    """Plain settings container consumed by ``DAN``.

    The class attributes below are the defaults. ``num_sent`` and ``encoder``
    must be supplied per instance; every additional keyword argument is stored
    as an instance attribute and therefore shadows the class-level default of
    the same name.
    """

    num_sent = 100        # sentences per document
    sent_len = 100        # tokens per sentence
    encoder_dim = 400     # DAN projects from 2 * encoder_dim features
    hidden_size = 400     # width of the DAN linear layers
    activation = 'relu'   # DAN checks for 'relu', 'tanh' and 'leakyrelu'
    dropout = 0.3         # probability handed to nn.Dropout by DAN

    def __init__(self, num_sent, encoder, **kwargs):
        """Record the required settings, then apply any keyword overrides."""
        self.num_sent = num_sent
        self.encoder = encoder
        # Same effect as calling setattr for each pair: the overrides land in
        # the instance dictionary.
        vars(self).update(kwargs)

class DAN(nn.Module):
    """Attention-based classifier over a pair of documents.

    Each document is a batch of token-id tensors shaped
    ``(batch, num_sent, sent_len)``. Every non-empty sentence is run through
    ``conf.encoder`` and projected to ``hidden_size``. The two documents then
    attend over each other's sentences, the attended pairs are compared and
    summed per document, and a final linear layer emits two logits.

    Args:
        conf: settings object exposing ``sent_len``, ``num_sent``, ``encoder``,
            ``encoder_dim``, ``hidden_size``, ``activation`` and ``dropout``.
            NOTE: the ``encoder`` attribute is deleted from ``conf`` during
            construction, so a conf instance can build only one model.

    Raises:
        ValueError: if ``conf.activation`` is not one of ``relu``, ``tanh`` or
            ``leakyrelu`` (case-insensitive).
    """

    # Supported activation names (compared case-insensitively).
    _ACTIVATIONS = {"relu": nn.ReLU, "tanh": nn.Tanh, "leakyrelu": nn.LeakyReLU}
    # Number of sentences handed to the encoder per call.
    _ENCODER_CHUNK = 64

    def __init__(self, conf):
        super(DAN, self).__init__()

        # Validate before touching ``conf`` so a bad name fails immediately
        # instead of leaving ``self.act`` undefined until the first forward.
        activation = conf.activation.lower()
        if activation not in self._ACTIVATIONS:
            raise ValueError(
                "Unsupported activation %r; expected one of %s"
                % (conf.activation, sorted(self._ACTIVATIONS))
            )

        self.conf = conf
        self.sent_len = conf.sent_len
        self.num_sent = conf.num_sent
        self.encoder = conf.encoder
        # The encoder now lives on the module; drop it from the stored conf.
        del self.conf.encoder
        # The encoder output is expected to carry 2 * encoder_dim features.
        self.translate = nn.Linear(2 * self.conf.encoder_dim, self.conf.hidden_size)
        # One-element parameter used only to look up the module's device.
        self.template = nn.Parameter(torch.zeros((1)), requires_grad=True)
        self.act = self._ACTIVATIONS[activation]()
        self.dropout = nn.Dropout(conf.dropout)

        self.mlp_f = nn.Linear(self.conf.hidden_size, self.conf.hidden_size)
        self.mlp_g = nn.Linear(2 * self.conf.hidden_size, self.conf.hidden_size)
        self.mlp_h = nn.Linear(2 * self.conf.hidden_size, self.conf.hidden_size)
        self.linear = nn.Linear(self.conf.hidden_size, 2)

    def encode_sent(self, inp):
        """Encode every sentence of a batch of documents.

        Args:
            inp: token ids shaped ``(batch, num_sent, sent_len)``.

        Returns:
            Tensor shaped ``(num_sent, batch, hidden_size)``. Sentences whose
            token ids sum to zero are treated as padding: they skip the encoder
            and enter the projection as all-zero vectors.
        """
        batch_size, _, _ = inp.shape
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = inp.reshape(-1, self.sent_len)

        # True for real sentences, False for all-padding rows.
        non_padded = x.sum(dim=1) != 0
        # Encode in chunks to bound the encoder's per-call batch size.
        x_enc = torch.cat(
            [
                self.encoder(sub_batch)[0]
                for sub_batch in x[non_padded].split(self._ENCODER_CHUNK)
            ],
            dim=0,
        )

        # Scatter the encodings back; padded rows stay zero.
        x_enc_t = torch.zeros(
            (batch_size * self.num_sent, x_enc.size(1)),
            device=self.template.device,
        )
        x_enc_t[non_padded] = x_enc
        x_enc_t = x_enc_t.view(batch_size, self.num_sent, -1)

        embedded = self.dropout(self.translate(x_enc_t))
        embedded = self.act(embedded)
        return embedded.permute(1, 0, 2)

    def forward(self, x0, x1):
        """Return logits shaped ``(batch, 2)`` for the document pair.

        Args:
            x0: token ids of the first document, ``(batch, num_sent, sent_len)``.
            x1: token ids of the second document, same shape.
        """
        # Back to (batch, num_sent, hidden) for the batched matmuls below.
        x0_enc = self.encode_sent(x0).permute(1, 0, 2)
        x1_enc = self.encode_sent(x1).permute(1, 0, 2)

        f1 = self.act(self.dropout(self.mlp_f(x0_enc)))
        f2 = self.act(self.dropout(self.mlp_f(x1_enc)))

        # score1[b, i, j]: affinity of sentence i in x0 with sentence j in x1.
        score1 = torch.bmm(f1, f2.transpose(1, 2))
        # Normalise over the sentences of the *other* document. The explicit
        # dim matches what the implicit-dim call on the flattened 2-D view did.
        prob1 = F.softmax(score1, dim=2)
        prob2 = F.softmax(score1.transpose(1, 2), dim=2)

        # Pair each sentence with its attention-weighted summary of the other doc.
        sent1_combine = torch.cat((x0_enc, torch.bmm(prob1, x1_enc)), 2)
        sent2_combine = torch.cat((x1_enc, torch.bmm(prob2, x0_enc)), 2)

        g1 = self.act(self.dropout(self.mlp_g(sent1_combine)))
        g2 = self.act(self.dropout(self.mlp_g(sent2_combine)))

        # Sum over sentences -> (batch, hidden). No squeeze afterwards: it was
        # a no-op except for hidden_size == 1, where it removed the feature axis.
        sent1_output = g1.sum(dim=1)
        sent2_output = g2.sum(dim=1)

        input_combine = torch.cat(
            (sent1_output * sent2_output, torch.abs(sent1_output - sent2_output)), 1
        )

        h = self.act(self.dropout(self.mlp_h(input_combine)))
        return self.linear(h)
    

        

In [17]:
from utils.load_models import load_bilstm_encoder, load_attn_encoder
from utils.helpers import seed_torch

# Load the attention encoder registered under the key "SNLI-12" through the
# project helper. The call returns a pair; `encoder` is what gets handed to
# DAN_conf below.
# NOTE(review): `Lang` is presumably the vocabulary/tokeniser object that goes
# with the encoder — confirm against utils.load_models.
encoder, Lang = load_attn_encoder("SNLI-12")

In [18]:
# Build a DAN for documents of 20 sentences around the loaded encoder.
model_conf = DAN_conf(20, encoder)
model = DAN(model_conf)

# Use the GPU when one is present, otherwise stay on the CPU, so this cell
# does not crash on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)


In [19]:
# Random token ids in [0, 10000): 32 documents x 20 sentences x 100 tokens.
x = torch.randint(low=0, high=10000, size=(32, 20, 100))

In [20]:
# Smoke test: feed the same batch as both documents. Move it once to the
# device the model lives on, rather than copying it to the GPU twice.
x_dev = x.to(next(model.parameters()).device)
opt = model(x_dev, x_dev)

torch.Size([32, 20, 20]) torch.Size([32, 20, 20])
torch.Size([32, 20, 800]) torch.Size([32, 20, 800])
torch.Size([32, 20, 400]) torch.Size([32, 20, 400])
torch.Size([32, 400]) torch.Size([32, 400])


In [21]:
# Check the logits' shape; the recorded run shows one row of 2 values per
# document pair, i.e. (32, 2).
print(opt.shape)

torch.Size([32, 2])


In [19]:
# Scratch: flatten to one row per sentence (100 token ids each), as
# DAN.encode_sent does internally.
# NOTE: this rebinds `x`, so the 3-D batch expected by the model-call cell is
# gone until the torch.randint cell is run again.
x = x.view(-1,100)

In [20]:
# 32 documents * 20 sentences = 640 rows of 100 tokens in the recorded run.
x.shape

torch.Size([640, 100])

In [21]:
# Same mask DAN.encode_sent computes. Despite the name, it is True for rows
# whose token ids do NOT sum to zero, i.e. the rows that are kept.
x_padded_idx = x.sum(dim=1) != 0

In [22]:
# One boolean per flattened sentence row (640 in the recorded run).
x_padded_idx.shape

torch.Size([640])

In [23]:
# Rows that survive the mask; all 640 were kept in the recorded run.
x[x_padded_idx].shape

torch.Size([640, 100])

In [1]:
import pandas as pd

In [3]:
import bigjson
# Open the Yelp review dump in binary mode and hand the handle to bigjson.
# NOTE(review): `data` is only used after this `with` block has closed the
# file; if bigjson reads lazily from the handle, the object may be unusable —
# confirm. A later cell rebinds `data` to a line-by-line reader of this file.
with open("dataset/yelp/yelp_academic_dataset_review.json", 'rb') as f:
    data = bigjson.load(f)

In [19]:
import json
def json_reader(filename, encoding="utf-8"):
    """Lazily yield one decoded JSON value per line of a JSON-lines file.

    Args:
        filename: path of the file to read.
        encoding: text encoding used to open the file. Defaults to UTF-8 so
            the result does not depend on the platform's locale.

    Yields:
        The object produced by ``json.loads`` for each non-blank line.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(filename, encoding=encoding) as f:
        for line in f:
            # Blank lines (e.g. a trailing newline) carry no record.
            if not line.strip():
                continue
            yield json.loads(line)


In [20]:
# Lazy generator over the Yelp reviews, one JSON object per line. This rebinds
# `data`, and the generator can be consumed only once.
data = json_reader("dataset/yelp/yelp_academic_dataset_review.json")

In [21]:
from tqdm import tqdm
# Keep only the review text and star rating, keyed by a running integer index.
data_json = {}
count = 0
# `total` is a hard-coded record count used only to size the progress bar;
# the recorded run iterated 8021122 items.
for i in tqdm(data,total=8021122):
    data_json[count] = {"text":i["text"],"label":i["stars"]}
    count+=1


8021122it [01:13, 108871.54it/s]                             


In [30]:
# Star ratings in insertion order of data_json.
labels = [review["label"] for review in data_json.values()]

In [32]:
import numpy as np

In [2]:
import os
import glob
import json
from collections import defaultdict


def _parse_trec_files(paths, docs, encoding=None):
    """Add every <DOC> record found in ``paths`` to ``docs``; return the count of malformed records."""
    errors = 0
    for path in paths:
        with open(path, "r", encoding=encoding) as f:
            txt = f.read()
        for chunk in txt.split("</DOC>"):
            # Fragments of 10 characters or fewer (e.g. trailing whitespace)
            # are not documents.
            if len(chunk) <= 10:
                continue
            doc = chunk.split("<DOC>")[1]
            try:
                doc_id = doc.split("<DOCNO>")[1].split("</DOCNO>")[0]
                text = doc.split("<TEXT>")[1].split("</TEXT>")[0]
            except IndexError:
                # A <DOCNO> or <TEXT> tag is missing from this record.
                errors += 1
                continue
            docs[doc_id.strip()] = text.strip()
    return errors


def _crawl(red_dict, doc, crawled):
    """Return the documents ``doc`` is redundant with, followed transitively through ``red_dict``."""
    ans = []
    for cdoc in red_dict[doc]:
        ans.append(cdoc)
        if crawled[cdoc] == 0:
            try:
                red_dict[cdoc] = _crawl(red_dict, cdoc, crawled)
                crawled[cdoc] = 1
                ans += red_dict[cdoc]
            except (KeyError, RecursionError):
                # KeyError: cdoc has no redundancy entry of its own.
                # RecursionError: a reference cycle exhausted the stack.
                # In both cases stop expanding below this node.
                crawled[cdoc] = 1
    return ans


def _build_instances(rows, docs_json, label):
    """Turn ``[topic, target_id, *source_ids]`` rows into dataset instances carrying ``label``."""
    instances = []
    for row in rows:
        target_id = row[1]
        source_ids = row[2:]
        if target_id not in docs_json:
            # Report the missing target and skip the row. Falling through here
            # used to re-append the previous row's instance.
            print(target_id)
            continue
        source = "".join(
            docs_json[source_id] + ". \n"
            for source_id in source_ids
            if source_id in docs_json
        )
        # Instances with no resolvable source document are dropped.
        if source != "":
            instances.append(
                {"target": docs_json[target_id], "source": source, "label": label}
            )
    return instances


def create_json(allow_partially_redundant=True):
    """Build the APWSJ novelty dataset and write it to disk.

    Reads the WSJ and AP corpora from ``./dataset/trec`` and the topic,
    redundancy and ordering files from ``./dataset/apwsj/NoveltyData``.
    Writes, under ``./dataset/apwsj``:

    * ``redundancy_list_without_partially_redundant.txt``: one line per
      redundant document, ``topic doc redundant_with...``, with redundancy
      followed transitively.
    * ``novel_list_without_partially_redundant.txt``: one line per
      non-redundant document (except the first of a topic), followed by up to
      five documents that precede it in the sorted order for that topic.
    * ``apwsj_dataset.json``: ``{index: {"target", "source", "label"}}`` with
      label 1 for novel rows and 0 for redundant rows.

    Args:
        allow_partially_redundant: when true (the default), redundancy lines
            whose third token is ``?`` are kept. Otherwise they are ignored.

    Returns:
        None. The results are the three files listed above.
    """
    # ---- 1. Document id -> text for both corpora -------------------------
    wsj_files = glob.glob("./dataset/trec/wsj" + "/*")
    ap_files = glob.glob("./dataset/trec/ap" + "/*")

    docs_json = {}
    errors = _parse_trec_files(wsj_files, docs_json)
    errors += _parse_trec_files(ap_files, docs_json, encoding="latin-1")
    print("Reading APWSJ dataset, Errors : ", errors)

    # ---- 2. Topic judgements: (doc_id, topic) pairs ------------------------
    topic_list = (
        "q101, q102, q103, q104, q105, q106, q107, q108, q109, q111, q112, "
        "q113, q114, q115, q116, q117, q118, q119, q120, q121, q123, q124, "
        "q125, q127, q128, q129, q132, q135, q136, q137, q138, q139, q141"
    ).split(", ")

    topic_doc = []
    with open("./dataset/apwsj/NoveltyData/apwsj.qrels", "r") as f:
        for line in f.read().split("\n"):
            if not line.strip():
                continue  # tolerate a trailing newline
            parts = line.split(" 0 ")
            # parts[1] is the doc id followed by two trailing characters.
            topic_doc.append((parts[1][:-2], parts[0]))

    # ---- 3. Redundancy lists, expanded transitively ------------------------
    red_path = "./dataset/apwsj/redundancy_list_without_partially_redundant.txt"
    red_dict = dict()
    with open("./dataset/apwsj/NoveltyData/redundancy.apwsj.result", "r") as f:
        for line in f:
            tokens = line.split()
            if not tokens:
                continue
            if tokens[2] == "?":
                if not allow_partially_redundant:
                    continue
                redundant_with = tokens[3:]
            else:
                redundant_with = tokens[2:]
            red_dict[tokens[0] + "/" + tokens[1]] = [
                tokens[0] + "/" + other for other in redundant_with
            ]

    crawled = defaultdict(int)
    for doc in red_dict:
        if crawled[doc] == 0:
            red_dict[doc] = _crawl(red_dict, doc, crawled)
            crawled[doc] = 1

    with open(red_path, "w") as f:
        for doc in red_dict:
            if doc.split("/")[0] in topic_list:
                f.write(
                    " ".join(
                        doc.split("/") + [other.split("/")[1] for other in red_dict[doc]]
                    )
                    + "\n"
                )

    # ---- 4. Novel documents with up to five preceding documents -------------
    novel_path = "./dataset/apwsj/novel_list_without_partially_redundant.txt"
    doc_topic_dict = defaultdict(list)
    for doc_id, topic in topic_doc:
        doc_topic_dict[doc_id].append(topic)

    with open("./dataset/apwsj/NoveltyData/apwsj88-90.rel.docno.sorted", "r") as f:
        docs_sorted = f.read().splitlines()

    # Per topic, its documents in the order given by the sorted docno file.
    sorted_doc_topic_dict = defaultdict(list)
    for doc in docs_sorted:
        for topic in doc_topic_dict.get(doc, ()):
            sorted_doc_topic_dict[topic].append(doc)

    redundant_pairs = set()
    with open(red_path, "r") as f:
        for line in f:
            tokens = line.split()
            redundant_pairs.add((tokens[0], tokens[1]))

    novel_list = []
    for topic in topic_list:
        ordered = sorted_doc_topic_dict[topic]
        for i, doc in enumerate(ordered):
            # The first document of a topic has nothing to be compared with.
            if i > 0 and (topic, doc) not in redundant_pairs:
                # take at most 5 latest docs in case of novel
                novel_list.append(" ".join([topic, doc] + ordered[max(0, i - 5) : i]))
    with open(novel_path, "w") as f:
        f.write("\n".join(novel_list))

    # ---- 5. Assemble and dump the dataset ------------------------------------
    with open(novel_path, "r") as f:
        novel_rows = [row for row in (l.split() for l in f.read().split("\n")) if row]
    with open(red_path, "r") as f:
        red_rows = [row for row in (l.split() for l in f.read().split("\n")) if row]

    dataset = _build_instances(novel_rows, docs_json, 1)
    dataset += _build_instances(red_rows, docs_json, 0)

    dataset_json = dict(enumerate(dataset))

    dataset_path = "./dataset/apwsj/apwsj_dataset.json"
    with open(dataset_path, "w") as f:
        json.dump(dataset_json, f)
        


In [4]:
# Read the whole of need_these.txt into memory as one string. Rebinds `data`.
with open('need_these.txt','r') as f:
    data = f.read()

In [5]:
# One list entry per line. If the file ends with a newline, the last entry is
# an empty string.
dat = data.split("\n")