# Load Data

In [2]:
! pip install pytorch_pretrained_bert
! pip install torchmetrics

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch_pretrained_bert
  Downloading pytorch_pretrained_bert-0.6.2-py3-none-any.whl (123 kB)
[K     |████████████████████████████████| 123 kB 6.9 MB/s 
Collecting boto3
  Downloading boto3-1.26.20-py3-none-any.whl (132 kB)
[K     |████████████████████████████████| 132 kB 45.9 MB/s 
Collecting botocore<1.30.0,>=1.29.20
  Downloading botocore-1.29.20-py3-none-any.whl (10.2 MB)
[K     |████████████████████████████████| 10.2 MB 19.8 MB/s 
[?25hCollecting jmespath<2.0.0,>=0.7.1
  Downloading jmespath-1.0.1-py3-none-any.whl (20 kB)
Collecting s3transfer<0.7.0,>=0.6.0
  Downloading s3transfer-0.6.0-py3-none-any.whl (79 kB)
[K     |████████████████████████████████| 79 kB 5.2 MB/s 
Collecting urllib3<1.27,>=1.25.4
  Downloading urllib3-1.26.13-py2.py3-none-any.whl (140 kB)
[K     |████████████████████████████████| 140 kB 11.7 MB/s 
  Downloading urllib3-1.25.11-py2.py3-none-any.w

In [3]:
from google.colab import drive
drive.mount('/content/drive')

import sys
sys.path.insert(0, '/content/drive/MyDrive/Colab Notebooks/Capstone')

import os
import pandas as pd
import numpy as np

from utils import read_conll_file, read_data

from torchmetrics.functional.classification import multiclass_f1_score, multiclass_precision, multiclass_recall, multiclass_accuracy

data_dir = "/content/drive/MyDrive/Colab Notebooks/Capstone/data/gweb_sancl"
wsj_dir = os.path.join(data_dir, "pos_fine", "wsj")
model_dir = "/content/drive/MyDrive/Colab Notebooks/Capstone/model"

Mounted at /content/drive


In [3]:
wsj_train_file = os.path.join(wsj_dir, "gweb-wsj-train.conll")
wsj_dev_file = os.path.join(wsj_dir, "gweb-wsj-dev.conll")
wsj_test_file = os.path.join(wsj_dir, "gweb-wsj-test.conll")

In [None]:
wsj_train_word_lst, wsj_train_tag_lst, wsj_train_tag_set = read_data(wsj_train_file)
wsj_dev_word_lst, wsj_dev_tag_lst, wsj_dev_tag_set = read_data(wsj_dev_file)
wsj_test_word_lst, wsj_test_tag_lst, wsj_test_tag_set = read_data(wsj_test_file)

The number of samples: 30060
The number of tags 48
The number of samples: 1336
The number of tags 45
The number of samples: 1640
The number of tags 45


In [None]:
wsj_tags = wsj_train_tag_set + wsj_dev_tag_set + wsj_test_tag_set
wsj_tags = sorted(list(set(wsj_tags)))
wsj_tags = ["<pad>"] + wsj_tags
tag2idx = {tag:idx for idx, tag in enumerate(wsj_tags)}
idx2tag = {idx:tag for idx, tag in enumerate(wsj_tags)}
print(len(wsj_tags))

49


# Build Model

In [7]:
from sklearn.metrics import precision_score, recall_score, f1_score, classification_report

import os
from tqdm import tqdm_notebook as tqdm
import numpy as np
import torch
import torch.nn as nn
from torch.utils import data
import torch.optim as optim
from pytorch_pretrained_bert import BertTokenizer

ModuleNotFoundError: ignored

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)

100%|██████████| 213450/213450 [00:00<00:00, 844408.70B/s]


In [None]:
class PosDataset(data.Dataset):
    def __init__(self, word_lst, tag_lst):
        sents, tags_li = [], [] # list of lists
        for i in range(len(word_lst)):
            sents.append(["[CLS]"] + word_lst[i] + ["[SEP]"])
            tags_li.append(["<pad>"] + tag_lst[i] + ["<pad>"])
        self.sents, self.tags_li = sents, tags_li

    def __len__(self):
        return len(self.sents)

    def __getitem__(self, idx):
        words, tags = self.sents[idx], self.tags_li[idx] # words, tags: string list

        # We give credits only to the first piece.
        x, y = [], [] # list of ids
        is_heads = [] # list. 1: the token is the first piece of a word
        for w, t in zip(words, tags):
            tokens = tokenizer.tokenize(w) if w not in ("[CLS]", "[SEP]") else [w]
            xx = tokenizer.convert_tokens_to_ids(tokens)

            is_head = [1] + [0]*(len(tokens) - 1)

            t = [t] + ["<pad>"] * (len(tokens) - 1)  # <PAD>: no decision
            yy = [tag2idx[each] for each in t]  # (T,)

            x.extend(xx)
            is_heads.extend(is_head)
            y.extend(yy)

        assert len(x)==len(y)==len(is_heads), "len(x)={}, len(y)={}, len(is_heads)={}".format(len(x), len(y), len(is_heads))

        # seqlen
        seqlen = len(y)

        # to string
        words = " ".join(words)
        tags = " ".join(tags)
        return words, x, is_heads, tags, y, seqlen


In [None]:
def pad(batch):
    '''Pads to the longest sample'''
    f = lambda x: [sample[x] for sample in batch]
    words = f(0)
    is_heads = f(2)
    tags = f(3)
    seqlens = f(-1)
    maxlen = np.array(seqlens).max()

    f = lambda x, seqlen: [sample[x] + [0] * (seqlen - len(sample[x])) for sample in batch] # 0: <pad>
    x = f(1, maxlen)
    y = f(-2, maxlen)


    f = torch.LongTensor

    return words, f(x), is_heads, tags, f(y), seqlens

In [None]:
from pytorch_pretrained_bert import BertModel

In [None]:
class Net(nn.Module):
    def __init__(self, vocab_size=None):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-cased')

        self.fc = nn.Linear(768, vocab_size)
        self.device = device

    def forward(self, x, y):
        '''
        x: (N, T). int64
        y: (N, T). int64
        '''
        x = x.to(device)
        y = y.to(device)
        
        if self.training:
            self.bert.train()
            encoded_layers, _ = self.bert(x)
            enc = encoded_layers[-1]
        else:
            self.bert.eval()
            with torch.no_grad():
                encoded_layers, _ = self.bert(x)
                enc = encoded_layers[-1]
        
        logits = self.fc(enc)
        y_hat = logits.argmax(-1)
        return logits, y, y_hat

In [None]:
def train(model, iterator, optimizer, criterion):
    model.train()
    for i, batch in enumerate(iterator):
        words, x, is_heads, tags, y, seqlens = batch
        _y = y # for monitoring
        optimizer.zero_grad()
        logits, y, _ = model(x, y) # logits: (N, T, VOCAB), y: (N, T)

        logits = logits.view(-1, logits.shape[-1]) # (N*T, VOCAB)
        y = y.view(-1)  # (N*T,)

        loss = criterion(logits, y)
        loss.backward()

        optimizer.step()

        if i%10==0: # monitoring
            print("step: {}, loss: {}".format(i, loss.item()))

In [None]:
def eval(model, iterator, average="weighted"):
    model.eval()

    Words, Is_heads, Tags, Y, Y_hat = [], [], [], [], []

    pred_lst = []
    true_lst = []

    with torch.no_grad():
        for i, batch in enumerate(iterator):
            words, x, is_heads, tags, y, seqlens = batch

            _, _, y_hat = model(x, y)  # y_hat: (N, T)

            for s in y_hat.cpu().numpy().tolist():
              pred_lst.extend(s)
            for s in y.numpy().tolist():
              true_lst.extend(s)

    precision_value = multiclass_precision(
            torch.tensor(pred_lst), torch.tensor(true_lst), num_classes=len(wsj_tags), ignore_index=0, 
            average=average)   
    recall_value = multiclass_recall(
            torch.tensor(pred_lst), torch.tensor(true_lst), num_classes=len(wsj_tags), ignore_index=0, 
            average=average)   
    f1_value = multiclass_f1_score(
            torch.tensor(pred_lst), torch.tensor(true_lst), num_classes=len(wsj_tags), ignore_index=0, 
            average=average)   
    acc = multiclass_accuracy(
        torch.tensor(pred_lst), torch.tensor(true_lst), num_classes=len(wsj_tags), ignore_index=0, 
        average=average)    


    return precision_value, recall_value, f1_value, acc

In [None]:
model = Net(vocab_size=len(tag2idx))
model.to(device)
model = nn.DataParallel(model)

100%|██████████| 404400730/404400730 [00:22<00:00, 18223503.74B/s]


In [None]:
train_dataset = PosDataset(wsj_train_word_lst, wsj_train_tag_lst)
eval_dataset = PosDataset(wsj_test_word_lst, wsj_test_tag_lst)

train_iter = data.DataLoader(dataset=train_dataset,
                             batch_size=8,
                             shuffle=True,
                             num_workers=1,
                             collate_fn=pad)
test_iter = data.DataLoader(dataset=eval_dataset,
                             batch_size=8,
                             shuffle=False,
                             num_workers=1,
                             collate_fn=pad)

optimizer = optim.Adam(model.parameters(), lr = 0.0001)

criterion = nn.CrossEntropyLoss(ignore_index=0)

In [None]:
# train(model, train_iter, optimizer, criterion)
# eval(model, test_iter)

# Save Model

In [None]:
model_file = os.path.join(model_dir, "base_model.pt")
# torch.save(model.state_dict(), model_file)

## Load Model

In [None]:
model = Net(vocab_size=len(tag2idx))
model.to(device)
model = nn.DataParallel(model)
model.load_state_dict(torch.load(model_file))
wsj_precision_value, wsj_recall_value, wsj_f1_value, wsj_acc_value = eval(model, test_iter)
print(wsj_precision_value, wsj_recall_value, wsj_f1_value, wsj_acc_value)

tensor(0.9771) tensor(0.9743) tensor(0.9751) tensor(0.9743)


# Self Training

In [None]:
def filter_tag(process_words, process_tags, label_tags_set=wsj_tags):
  new_words = []
  new_tags = []
  for words, tags in zip(process_words, process_tags):
    w_lst = []
    t_lst = []
    for i, t in enumerate(tags):
      if t in label_tags_set:
        w_lst.append(words[i])
        t_lst.append(tags[i])

    if w_lst:
      new_words.append(w_lst)
      new_tags.append(t_lst)
  print("after filter tag", len(new_words))
  return new_words, new_tags

In [None]:
file_name_lst = ["answers", "emails", "newsgroups", "reviews", "weblogs"]

In [None]:
domain = "weblogs"
domain_dir = os.path.join(data_dir, "pos_fine", f"{domain}")
domain_dev_file = os.path.join(domain_dir, f"gweb-{domain}-dev.conll")
domain_test_file = os.path.join(domain_dir, f"gweb-{domain}-test.conll")

In [None]:
domain_dev_word_lst, domain_dev_tag_lst, domain_dev_tag_set = read_data(domain_dev_file)
domain_test_word_lst, domain_test_tag_lst, domain_test_tag_set = read_data(domain_test_file)
domain_dev_word_lst, domain_dev_tag_lst = filter_tag(domain_dev_word_lst, domain_dev_tag_lst)  
domain_test_word_lst, domain_test_tag_lst = filter_tag(domain_test_word_lst, domain_test_tag_lst)

The number of samples: 1016
The number of tags 47
The number of samples: 1015
The number of tags 49
after filter tag 1016
after filter tag 974


In [None]:
domain_precision_value_lst = []
domain_recall_value_lst = []
domain_f1_value_lst = []
domain_acc_value_lst = []

In [None]:
domain_test_dataset = PosDataset(domain_test_word_lst, domain_test_tag_lst)

domain_test_iter = data.DataLoader(dataset=domain_test_dataset,
                             batch_size=8,
                             shuffle=False,
                             num_workers=1,
                             collate_fn=pad)

domain_precision_value, domain_recall_value, domain_f1_value, domain_acc_value = eval(model, domain_test_iter)

domain_precision_value_lst.append(domain_precision_value)
domain_recall_value_lst.append(domain_recall_value)
domain_f1_value_lst.append(domain_f1_value)
domain_acc_value_lst.append(domain_acc_value)

In [None]:
class PosDataset_new(data.Dataset):
    def __init__(self, word_lst, tag_lst):
        self.word_lst, self.tag_lst = word_lst, tag_lst

    def __len__(self):
      return len(self.word_lst)

    def __getitem__(self, idx):
      words, tags = self.word_lst[idx], self.tag_lst[idx] # words, tags: string list
      assert len(words)==len(tags)
        # seqlen
      seqlen = len(words)

      return words, tags, seqlen

def pad_new(batch):
    '''Pads to the longest sample'''
    f = lambda x: [sample[x] for sample in batch]
    words = f(0)
    tags = f(1)
    seqlens = f(-1)
    maxlen = np.array(seqlens).max()

    f = lambda x, seqlen: [sample[x] + [0] * (seqlen - len(sample[x])) for sample in batch] # 0: <pad>
    x = f(0, maxlen)
    y = f(1, maxlen)

    f = torch.LongTensor

    return f(x), f(y), seqlens

def train_new(model, iterator, optimizer, criterion):
    model.train()
    for i, batch in enumerate(iterator):
        x, y, seqlens = batch
        
        optimizer.zero_grad()
        logits, y, _ = model(x, y) # logits: (N, T, VOCAB), y: (N, T)

        logits = logits.view(-1, logits.shape[-1]) # (N*T, VOCAB)
        y = y.view(-1)  # (N*T,)

        loss = criterion(logits, y)
        loss.backward()

        optimizer.step()

        if i%10==0: # monitoring
            print("step: {}, loss: {}".format(i, loss.item()))

In [None]:
def gen_pseudo_data(model, domain_dev_iter, topn=300, initial=True):
  model.eval()

  LLD = []
  MEAN_PROB = []
  new_x_lst = []
  new_y_lst = []
  acc_lst = []

  if initial:
    with torch.no_grad():
        for i, batch in enumerate(domain_dev_iter):

          _, x, _, _, y, _ = batch
          # When calculating the length of sentences, ignore <pad>
          sen_len = y.bool().sum(axis=1)

          logits, _, y_hat = model(x, y)  # y_hat: (N, T)

          # Save prediction as new training dataset
          softmax_value = torch.softmax(logits, dim=2)
          max_prob = torch.amax(softmax_value, dim=2)

          # Rank by mean probability
          res_prob = y.bool().to(device) * max_prob.to(device)
          sum_prob = res_prob.sum(axis=1)
          mean_prob = sum_prob / sen_len.to(device)
          MEAN_PROB.extend(mean_prob.tolist())
          
          new_x_lst.extend(x.tolist())
          new_y_lst.extend(y_hat.tolist())

          # Calculate the accuracy for each sentences, ignore 0
          batch_acc = multiclass_accuracy(
              torch.tensor(y_hat).to(device), torch.tensor(y).to(device), num_classes=len(wsj_tags), 
              ignore_index=0, average="micro", multidim_average="samplewise")
          acc_lst.extend(batch_acc.tolist())
          

  else:
    with torch.no_grad():
        for i, batch in enumerate(domain_dev_iter):

          x, y, seqlens = batch
          sen_len = y.bool().sum(axis=1)

          logits, _, y_hat = model(x, y)  # y_hat: (N, T)

          # Save prediction as new training dataset
          softmax_value = torch.softmax(logits, dim=2)
          max_prob = torch.amax(softmax_value, dim=2)

          # Rank by mean probability
          res_prob = y.bool().to(device) * max_prob.to(device)
          sum_prob = res_prob.sum(axis=1)
          mean_prob = sum_prob / sen_len.to(device)
          MEAN_PROB.extend(mean_prob.tolist())
          
          new_x_lst.extend(x.tolist())
          new_y_lst.extend(y_hat.tolist())

          # Calculate the accuracy for each sentences, ignore 0
          batch_acc = multiclass_accuracy(
              torch.tensor(y_hat).to(device), torch.tensor(y).to(device), num_classes=len(wsj_tags), 
              ignore_index=0, average="micro", multidim_average="samplewise")
          acc_lst.extend(batch_acc.tolist())

  ind = list(range(len(MEAN_PROB)))
  ind = [x for _, x in sorted(zip(MEAN_PROB, ind), reverse=True)]
  prob_lst = [prob for prob, _ in sorted(zip(MEAN_PROB, ind), reverse=True)]

  select_ind = ind[: topn] # The index of topn sentences
  not_select_ind = ind[topn: ]

  new_train_x = [new_x_lst[i] for i in select_ind]
  new_train_y = [new_y_lst[i] for i in select_ind]

  remain_train_x = [new_x_lst[i] for i in not_select_ind]
  remain_train_y = [new_y_lst[i] for i in not_select_ind]

  new_prob = prob_lst[: topn]
  remain_prob = prob_lst[topn: ]
  new_acc = [acc_lst[i] for i in select_ind]
  remain_acc = [acc_lst[i] for i in not_select_ind]


  return new_train_x, new_train_y, remain_train_x, remain_train_y, new_acc, remain_acc, new_prob, remain_prob

In [None]:
acc_lst = []
prob_lst = []

factor_list = [1, 2, 5, 10, 20]
factor = factor_list[3] #  to be modified
topn = round(factor * len(domain_dev_word_lst) / 100)

i = 0
while len(domain_dev_word_lst) >= topn:
  i += 1
  print("\nLoop", i)
  print("domain_dev_word_lst", len(domain_dev_word_lst))

  if i == 1:
    domain_dev_dataset = PosDataset(domain_dev_word_lst, domain_dev_tag_lst)

    domain_dev_iter = data.DataLoader(dataset=domain_dev_dataset,
                                batch_size=8,
                                shuffle=False,
                                num_workers=1,
                                collate_fn=pad)
  else:
    domain_dev_dataset = PosDataset_new(domain_dev_word_lst, domain_dev_tag_lst)

    domain_dev_iter = data.DataLoader(dataset=domain_dev_dataset,
                                batch_size=8,
                                shuffle=True,
                                num_workers=1,
                                collate_fn=pad_new)
  
  initial = True if i==1 else False
  top_words_ids, top_tags_ids, domain_dev_word_lst, domain_dev_tag_lst, new_acc, remain_acc, new_prob, remain_prob = gen_pseudo_data(model, domain_dev_iter, topn, initial)

  # Revert ids to words
  top_words = []
  top_tags = []
  for t in range(len(top_words_ids)):
    word_ids = tokenizer.convert_ids_to_tokens(top_words_ids[t])
    tag_ids = list(map(idx2tag.get, top_tags_ids[t]))
    words = []
    tags = []
    for k, w in enumerate(word_ids):
      if w == '[CLS]':
        pass
      elif w == '[SEP]':
        break
      else:
        words.append(w)
        tags.append(tag_ids[k])
    top_words.append(words)
    top_tags.append(tags)

  new_train_dataset = PosDataset(wsj_train_word_lst+top_words, wsj_train_tag_lst+top_tags)
  new_train_iter = data.DataLoader(dataset=new_train_dataset,
                              batch_size=8,
                              shuffle=True,
                              num_workers=1,
                              collate_fn=pad)

  print("Train from scratch...")
  model = Net(vocab_size=len(tag2idx))
  model.to(device)
  model = nn.DataParallel(model)

  optimizer = optim.Adam(model.parameters(), lr = 0.0001)
  criterion = nn.CrossEntropyLoss(ignore_index=0)

  train(model, new_train_iter, optimizer, criterion)

  domain_precision_value, domain_recall_value, domain_f1_value, domain_acc_value = eval(model, domain_test_iter)

  domain_precision_value_lst.append(domain_precision_value)
  domain_recall_value_lst.append(domain_recall_value)
  domain_f1_value_lst.append(domain_f1_value)
  domain_acc_value_lst.append(domain_acc_value)

  acc_lst.append(new_acc)
  prob_lst.append(new_prob)



Loop 1
domain_dev_word_lst 1016


  torch.tensor(y_hat).to(device), torch.tensor(y).to(device), num_classes=len(wsj_tags),


Train from scratch...
step: 0, loss: 3.9304182529449463
step: 10, loss: 1.8634181022644043
step: 20, loss: 0.9408775568008423
step: 30, loss: 0.4345119297504425
step: 40, loss: 0.2426508069038391
step: 50, loss: 0.2724077105522156
step: 60, loss: 0.15933601558208466
step: 70, loss: 0.1676059365272522
step: 80, loss: 0.21330897510051727
step: 90, loss: 0.1347602903842926
step: 100, loss: 0.07239922881126404
step: 110, loss: 0.10935758799314499
step: 120, loss: 0.23317985236644745
step: 130, loss: 0.12763506174087524
step: 140, loss: 0.11532191932201385
step: 150, loss: 0.2229335755109787
step: 160, loss: 0.18676137924194336
step: 170, loss: 0.0820370465517044
step: 180, loss: 0.21269381046295166
step: 190, loss: 0.08923698216676712
step: 200, loss: 0.18482738733291626
step: 210, loss: 0.09691624343395233
step: 220, loss: 0.08529931306838989
step: 230, loss: 0.06023191660642624
step: 240, loss: 0.13161224126815796
step: 250, loss: 0.1318320780992508
step: 260, loss: 0.1451229453086853
st

  torch.tensor(y_hat).to(device), torch.tensor(y).to(device), num_classes=len(wsj_tags),


Train from scratch...
step: 0, loss: 3.947545051574707
step: 10, loss: 1.81839120388031
step: 20, loss: 0.7410621643066406
step: 30, loss: 0.4402027726173401
step: 40, loss: 0.3797014653682709
step: 50, loss: 0.19948604702949524
step: 60, loss: 0.12781402468681335
step: 70, loss: 0.28110232949256897
step: 80, loss: 0.180606409907341
step: 90, loss: 0.12840670347213745
step: 100, loss: 0.20387208461761475
step: 110, loss: 0.15809954702854156
step: 120, loss: 0.12034683674573898
step: 130, loss: 0.1773402839899063
step: 140, loss: 0.16111023724079132
step: 150, loss: 0.2027190774679184
step: 160, loss: 0.10519345849752426
step: 170, loss: 0.13140225410461426
step: 180, loss: 0.16128839552402496
step: 190, loss: 0.2089368849992752
step: 200, loss: 0.07861389964818954
step: 210, loss: 0.09260715544223785
step: 220, loss: 0.08267010003328323
step: 230, loss: 0.08966845273971558
step: 240, loss: 0.16375243663787842
step: 250, loss: 0.15278558433055878
step: 260, loss: 0.06763710081577301
ste

In [None]:
print(domain_precision_value_lst)
print(domain_recall_value_lst)
print(domain_f1_value_lst)
print(domain_acc_value_lst)

print(acc_lst)
print(prob_lst)

[tensor(0.9466), tensor(0.9481), tensor(0.9471), tensor(0.9505), tensor(0.9502), tensor(0.9529), tensor(0.9476), tensor(0.9483), tensor(0.9488), tensor(0.9455)]
[tensor(0.9421), tensor(0.9422), tensor(0.9421), tensor(0.9466), tensor(0.9459), tensor(0.9471), tensor(0.9427), tensor(0.9436), tensor(0.9443), tensor(0.9394)]
[tensor(0.9392), tensor(0.9397), tensor(0.9402), tensor(0.9441), tensor(0.9432), tensor(0.9450), tensor(0.9403), tensor(0.9416), tensor(0.9427), tensor(0.9363)]
[tensor(0.9421), tensor(0.9422), tensor(0.9421), tensor(0.9466), tensor(0.9459), tensor(0.9471), tensor(0.9427), tensor(0.9436), tensor(0.9443), tensor(0.9394)]
[[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.95652174949646, 1.0, 0.8947368264198303, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.931034505367279, 1.0, 1.0, 1.0, 1.0, 1.0, 1.

In [6]:
import pandas as pd

In [None]:
test_metric = pd.DataFrame({
    "Loop": list(range(len(domain_precision_value_lst))) * 3,
    "metric": ["precision"]*len(domain_precision_value_lst) + ["recall"]*len(domain_precision_value_lst) + ["f1"]*len(domain_precision_value_lst),
    "value": domain_precision_value_lst + domain_recall_value_lst + domain_f1_value_lst
})

In [4]:
import seaborn as sns
import matplotlib.pyplot as plt

In [5]:
import plotly
import plotly.express as px
import plotly.graph_objects as go

In [None]:
fig = px.line(test_metric, x="Loop", y="value", color='metric', markers=True)
fig.show()

In [None]:
file_name = f"topn_{factor}%_domain_{domain}"
scratch_model_dir = "/content/drive/MyDrive/Colab Notebooks/Capstone/scratch_fixed/model"
log_model_dir = "/content/drive/MyDrive/Colab Notebooks/Capstone/scratch_fixed/result"
test_metric.to_csv(os.path.join(log_model_dir, file_name) + '.csv')

torch.save(model.state_dict(), os.path.join(scratch_model_dir, file_name))

In [7]:
# file_name
factor = 10

In [None]:
acc_lst = []
prob_lst = []

factor_list = [1, 2, 5, 10, 20]
factor = factor_list[3] #  to be modified
topn = round(factor * len(domain_dev_word_lst) / 100)

i = 0
while len(domain_dev_word_lst) >= topn:
  i += 1
  print("\nLoop", i)
  print("domain_dev_word_lst", len(domain_dev_word_lst))

  if i == 1:
    domain_dev_dataset = PosDataset(domain_dev_word_lst, domain_dev_tag_lst)

    domain_dev_iter = data.DataLoader(dataset=domain_dev_dataset,
                                batch_size=8,
                                shuffle=False,
                                num_workers=1,
                                collate_fn=pad)
  else:
    domain_dev_dataset = PosDataset_new(domain_dev_word_lst, domain_dev_tag_lst)

    domain_dev_iter = data.DataLoader(dataset=domain_dev_dataset,
                                batch_size=8,
                                shuffle=True,
                                num_workers=1,
                                collate_fn=pad_new)
  
  initial = True if i==1 else False
  top_words_ids, top_tags_ids, domain_dev_word_lst, domain_dev_tag_lst, new_acc, remain_acc, new_prob, remain_prob = gen_pseudo_data(model, domain_dev_iter, topn, initial)

  # Revert ids to words
  top_words = []
  top_tags = []
  for t in range(len(top_words_ids)):
    word_ids = tokenizer.convert_ids_to_tokens(top_words_ids[t])
    tag_ids = list(map(idx2tag.get, top_tags_ids[t]))
    words = []
    tags = []
    for k, w in enumerate(word_ids):
      if w == '[CLS]':
        pass
      elif w == '[SEP]':
        break
      else:
        words.append(w)
        tags.append(tag_ids[k])
    top_words.append(words)
    top_tags.append(tags)

  new_train_dataset = PosDataset(wsj_train_word_lst+top_words, wsj_train_tag_lst+top_tags)
  new_train_iter = data.DataLoader(dataset=new_train_dataset,
                              batch_size=8,
                              shuffle=True,
                              num_workers=1,
                              collate_fn=pad)

  print("Train from scratch...")
  model = Net(vocab_size=len(tag2idx))
  model.to(device)
  model = nn.DataParallel(model)

  optimizer = optim.Adam(model.parameters(), lr = 0.0001)
  criterion = nn.CrossEntropyLoss(ignore_index=0)

  train(model, new_train_iter, optimizer, criterion)

  domain_precision_value, domain_recall_value, domain_f1_value, domain_acc_value = eval(model, domain_test_iter)

  domain_precision_value_lst.append(domain_precision_value)
  domain_recall_value_lst.append(domain_recall_value)
  domain_f1_value_lst.append(domain_f1_value)
  domain_acc_value_lst.append(domain_acc_value)

  acc_lst.append(new_acc)
  prob_lst.append(new_prob)


In [9]:
factor = 10

In [13]:
domain = "newsgroups"
log_model_dir = "/content/drive/MyDrive/Colab Notebooks/Capstone/scratch_fixed/result"

file_name = f"topn_{factor}%_domain_{domain}"
test_metric = pd.read_csv(os.path.join(log_model_dir, file_name) + '.csv', index_col=0)
test_metric['value'] = test_metric.value.apply(lambda x: float(x[-7:-1]))

fig = px.line(test_metric, x="Loop", y="value", color='metric', markers=True)
fig.show()

In [12]:
!ls '/content/drive/MyDrive/Colab Notebooks/Capstone/scratch_fixed/result/'

 topn_10%_domain_answers.csv	  topn_20%_domain_newsgroups.csv
 topn_10%_domain_emails.csv	  topn_20%_domain_reviews.csv
 topn_10%_domain_newsgroups.csv   topn_20%_domain_weblogs.csv
 topn_10%_domain_reviews.csv	  topn_300_domain_answers.csv
 topn_10%_domain_weblogs.csv	  topn_300_domain_emails.csv
'topn_20\%_domain_answers.csv'	  topn_300_domain_newsgroups.csv
 topn_20%_domain_answers.csv	  topn_300_domain_reviews.csv
 topn_20%_domain_emails.csv	  topn_300_domain_weblogs.csv


In [27]:
domain = "answers"
log_model_dir = "/content/drive/MyDrive/Colab Notebooks/Capstone/scratch_fixed/result"

file_name = f"topn_{factor}%_domain_{domain}"
test_metric = pd.read_csv(os.path.join(log_model_dir, file_name) + '.csv', index_col=0)
test_metric['value'] = test_metric.value.apply(lambda x: float(x[-7:-1]))

fig = px.line(test_metric, x="Loop", y="value", color='metric', markers=True)
fig.show()

In [8]:
domain = "newsgroups"
log_model_dir = "/content/drive/MyDrive/Colab Notebooks/Capstone/scratch_fixed/result"

file_name = f"topn_{factor}%_domain_{domain}"
test_metric = pd.read_csv(os.path.join(log_model_dir, file_name) + '.csv', index_col=0)
test_metric['value'] = test_metric.value.apply(lambda x: float(x[-7:-1]))

fig = px.line(test_metric, x="Loop", y="value", color='metric', markers=True)
fig.show()

In [29]:
domain = "reviews"
log_model_dir = "/content/drive/MyDrive/Colab Notebooks/Capstone/scratch_fixed/result"

file_name = f"topn_{factor}%_domain_{domain}"
test_metric = pd.read_csv(os.path.join(log_model_dir, file_name) + '.csv', index_col=0)
test_metric['value'] = test_metric.value.apply(lambda x: float(x[-7:-1]))

fig = px.line(test_metric, x="Loop", y="value", color='metric', markers=True)
fig.show()

In [39]:
domain = "weblogs"
log_model_dir = "/content/drive/MyDrive/Colab Notebooks/Capstone/scratch_fixed/result"

file_name = f"topn_{factor}%_domain_{domain}"
test_metric = pd.read_csv(os.path.join(log_model_dir, file_name) + '.csv', index_col=0)
test_metric['value'] = test_metric.value.apply(lambda x: float(x[-7:-1]))

fig = px.line(test_metric, x="Loop", y="value", color='metric', markers=True)
fig.show()

In [13]:
fig = px.line(test_metric, x="Loop", y="value", color='metric', markers=True)

fig.update_layout(legend=dict(
    orientation="h",
    yanchor="bottom",
    y=1.02,
    xanchor="right",
    x=1
))

fig.show(scale=6, width=500, height=500)

In [1]:
!pip install kaleido


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting kaleido
  Downloading kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl (79.9 MB)
[K     |████████████████████████████████| 79.9 MB 79 kB/s 
[?25hInstalling collected packages: kaleido
Successfully installed kaleido-0.2.1


In [14]:
import kaleido #required
kaleido.__version__ #0.2.1

import plotly
plotly.__version__ #5.5.0

#now this works:
import plotly.graph_objects as go

# fig = go.Figure()
fig.write_image("/content/drive/MyDrive/Colab Notebooks/Capstone/scratch_fixed/tmp.png", scale=6, width=500, height=500)

In [None]:
file_name_lst = ["answers", "emails", "newsgroups", "reviews", "weblogs"]

In [17]:
test_metric.value[0]

'tensor(0.9122)'

In [20]:
# [tensor(0.9466), tensor(0.9481), tensor(0.9471), tensor(0.9505), tensor(0.9502), tensor(0.9529), tensor(0.9476), tensor(0.9483), tensor(0.9488), tensor(0.9455)]
# [tensor(0.9421), tensor(0.9422), tensor(0.9421), tensor(0.9466), tensor(0.9459), tensor(0.9471), tensor(0.9427), tensor(0.9436), tensor(0.9443), tensor(0.9394)]
# [tensor(0.9392), tensor(0.9397), tensor(0.9402), tensor(0.9441), tensor(0.9432), tensor(0.9450), tensor(0.9403), tensor(0.9416), tensor(0.9427), tensor(0.9363)]
# [tensor(0.9421), tensor(0.9422), tensor(0.9421), tensor(0.9466), tensor(0.9459), tensor(0.9471), tensor(0.9427), tensor(0.9436), tensor(0.9443), tensor(0.9394)]
acc = [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.95652174949646, 1.0, 0.8947368264198303, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.931034505367279, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.9375, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9534883499145508, 1.0, 1.0, 0.9090909361839294, 1.0, 1.0, 0.9444444179534912, 1.0, 1.0, 1.0, 1.0, 0.96875, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9230769276618958, 0.8999999761581421, 1.0, 1.0, 1.0, 1.0, 1.0], [0.9342105388641357, 0.9672130942344666, 0.9473684430122375, 0.9589040875434875, 0.9230769276618958, 0.9726027250289917, 0.9523809552192688, 0.9624999761581421, 0.9411764740943909, 0.9591836929321289, 0.9387755393981934, 0.9577465057373047, 0.9487179517745972, 0.949999988079071, 0.9718309640884399, 0.914893627166748, 0.9523809552192688, 0.9473684430122375, 0.9636363387107849, 0.9411764740943909, 0.9718309640884399, 0.9777777791023254, 0.9710144996643066, 0.9714285731315613, 0.9558823704719543, 0.7692307829856873, 0.9577465057373047, 0.9599999785423279, 0.957446813583374, 0.9365079402923584, 0.9677419066429138, 0.95652174949646, 0.9722222089767456, 0.9142857193946838, 0.9523809552192688, 0.9583333134651184, 0.9487179517745972, 0.9607843160629272, 0.9487179517745972, 0.9750000238418579, 0.9591836929321289, 0.9672130942344666, 0.9642857313156128, 0.9473684430122375, 0.9523809552192688, 0.9523809552192688, 0.9428571462631226, 0.9591836929321289, 0.9599999785423279, 0.9428571462631226, 0.9607843160629272, 0.9399999976158142, 0.9375, 0.9583333134651184, 0.9056603908538818, 0.9487179517745972, 0.9275362491607666, 0.9090909361839294, 0.930232584476471, 0.9636363387107849, 0.9399999976158142, 0.9230769276618958, 0.9607843160629272, 0.8870967626571655, 0.9523809552192688, 0.9583333134651184, 0.9433962106704712, 0.9347826242446899, 0.970588207244873, 0.9487179517745972, 0.9523809552192688, 0.925000011920929, 0.9473684430122375, 0.9344262480735779, 0.9722222089767456, 0.9473684430122375, 0.9594594836235046, 0.9433962106704712, 0.9285714030265808, 0.95652174949646, 0.9534883499145508, 0.9444444179534912, 0.9420289993286133, 0.95652174949646, 0.9682539701461792, 0.9487179517745972, 0.925000011920929, 0.9230769276618958, 0.9402984976768494, 0.9701492786407471, 0.9318181872367859, 0.9454545378684998, 0.9047619104385376, 0.9130434989929199, 0.9111111164093018, 0.930232584476471, 0.9090909361839294, 0.95652174949646, 0.9487179517745972, 0.9285714030265808, 0.9682539701461792, 0.970588207244873], [0.9583333134651184, 0.9285714030265808, 0.9594594836235046, 0.9305555820465088, 0.970588207244873, 0.9666666388511658, 0.9726027250289917, 0.9552238583564758, 0.9759036302566528, 0.9819819927215576, 0.9733333587646484, 0.9642857313156128, 0.9777777791023254, 0.9193548560142517, 0.976190447807312, 0.9642857313156128, 0.9722222089767456, 0.9718309640884399, 0.9677419066429138, 0.9594594836235046, 0.9666666388511658, 0.9759036302566528, 0.9733333587646484, 0.9518072009086609, 0.9111111164093018, 0.9719626307487488, 0.9583333134651184, 0.9714285731315613, 0.954954981803894, 0.8777777552604675, 0.9345794320106506, 0.9583333134651184, 0.9777777791023254, 0.9583333134651184, 0.9726027250289917, 0.9736841917037964, 0.9729729890823364, 0.9722222089767456, 0.9777777791023254, 0.9819819927215576, 0.9622641801834106, 0.9714285731315613, 0.9722222089767456, 0.9583333134651184, 0.9819819927215576, 0.9558823704719543, 0.9729729890823364, 0.970588207244873, 0.9759036302566528, 0.9714285731315613, 0.9264705777168274, 0.9677419066429138, 0.9638554453849792, 0.9750000238418579, 0.9736841917037964, 0.9722222089767456, 0.9813084006309509, 0.9743589758872986, 0.9555555582046509, 0.9599999785423279, 0.9444444179534912, 0.9819819927215576, 0.9682539701461792, 0.9615384340286255, 0.970588207244873, 0.976190447807312, 0.9555555582046509, 0.9719626307487488, 0.9508196711540222, 0.9813084006309509, 0.9719626307487488, 0.9594594836235046, 0.9666666388511658, 0.9714285731315613, 0.9428571462631226, 0.9639639854431152, 0.9719626307487488, 0.9066666960716248, 0.9714285731315613, 0.9404761791229248, 0.9552238583564758, 0.9552238583564758, 0.9666666388511658, 0.9555555582046509, 0.9777777791023254, 0.9404761791229248, 0.9677419066429138, 0.976190447807312, 0.9577465057373047, 0.9583333134651184, 0.9722222089767456, 0.9813084006309509, 0.9722222089767456, 0.9714285731315613, 0.9777777791023254, 0.9743589758872986, 0.9571428298950195, 0.9523809552192688, 0.9729729890823364, 0.9444444179534912, 0.9714285731315613, 0.9743589758872986], [0.9813084006309509, 0.9666666388511658, 0.9599999785423279, 0.9813084006309509, 0.9729729890823364, 0.9729729890823364, 0.9639639854431152, 0.9639639854431152, 0.9375, 0.9626168012619019, 0.976190447807312, 0.5777778029441833, 0.9729729890823364, 0.9439252614974976, 0.9626168012619019, 0.9729729890823364, 0.9729729890823364, 0.9523809552192688, 0.9666666388511658, 0.9639639854431152, 0.9813084006309509, 0.9719626307487488, 0.9666666388511658, 0.9813084006309509, 0.9444444179534912, 0.9819819927215576, 0.9729729890823364, 0.9719626307487488, 0.9639639854431152, 0.773809552192688, 0.954954981803894, 0.9439252614974976, 0.9719626307487488, 0.9666666388511658, 0.9719626307487488, 0.9189189076423645, 0.9624999761581421, 0.9729729890823364, 0.954954981803894, 0.9819819927215576, 0.9555555582046509, 0.954954981803894, 0.9719626307487488, 0.954954981803894, 0.9729729890823364, 0.954954981803894, 0.9333333373069763, 0.9719626307487488, 0.9626168012619019, 0.9333333373069763, 0.9397590160369873, 0.9729729890823364, 0.9626168012619019, 0.9369369149208069, 0.9813084006309509, 0.9626168012619019, 0.9759036302566528, 0.9252336621284485, 0.9666666388511658, 0.9729729890823364, 0.9819819927215576, 0.9719626307487488, 0.9719626307487488, 0.9819819927215576, 0.9819819927215576, 0.954954981803894, 0.9605262875556946, 0.9759036302566528, 0.9777777791023254, 0.9158878326416016, 0.9777777791023254, 0.9729729890823364, 0.9666666388511658, 0.9759036302566528, 0.9639639854431152, 0.9555555582046509, 0.9813084006309509, 0.9666666388511658, 0.9639639854431152, 0.9666666388511658, 0.9345794320106506, 0.9813084006309509, 0.9599999785423279, 0.9719626307487488, 0.9719626307487488, 0.9819819927215576, 0.9487179517745972, 0.9439252614974976, 0.9819819927215576, 0.9666666388511658, 0.9626168012619019, 0.9813084006309509, 0.9639639854431152, 0.954954981803894, 0.9819819927215576, 0.976190447807312, 0.9729729890823364, 0.9819819927215576, 0.9444444179534912, 0.9813084006309509, 0.9626168012619019, 0.9555555582046509], [0.9909909963607788, 0.9729729890823364, 0.9909909963607788, 0.9639639854431152, 0.9459459185600281, 0.9639639854431152, 0.9369369149208069, 0.9909909963607788, 0.9639639854431152, 0.9819819927215576, 0.9819819927215576, 0.9819819927215576, 0.954954981803894, 0.9909909963607788, 1.0, 0.9819819927215576, 0.9819819927215576, 0.9729729890823364, 0.954954981803894, 0.9729729890823364, 0.9729729890823364, 0.9459459185600281, 0.9099099040031433, 0.9729729890823364, 1.0, 0.9459459185600281, 0.9639639854431152, 0.9819819927215576, 0.9909909963607788, 0.9729729890823364, 0.9459459185600281, 0.9639639854431152, 1.0, 0.9909909963607788, 0.9819819927215576, 0.9639639854431152, 0.9909909963607788, 0.9909909963607788, 1.0, 0.9909909963607788, 0.9819819927215576, 0.9909909963607788, 0.9639639854431152, 0.9909909963607788, 1.0, 0.9819819927215576, 0.9459459185600281, 0.9729729890823364, 0.9819819927215576, 0.9819819927215576, 0.9909909963607788, 0.9279279112815857, 0.9819819927215576, 0.9909909963607788, 0.9909909963607788, 0.954954981803894, 0.9729729890823364, 0.9279279112815857, 0.9909909963607788, 0.9909909963607788, 0.954954981803894, 0.9729729890823364, 0.9009009003639221, 0.9279279112815857, 0.9729729890823364, 0.9819819927215576, 0.9819819927215576, 0.9819819927215576, 0.9369369149208069, 0.8918918967247009, 0.9819819927215576, 0.9729729890823364, 1.0, 0.9909909963607788, 0.9729729890823364, 0.9819819927215576, 0.9909909963607788, 0.9099099040031433, 0.9909909963607788, 1.0, 0.9369369149208069, 0.9819819927215576, 0.9819819927215576, 0.9909909963607788, 0.9819819927215576, 0.9279279112815857, 1.0, 0.9729729890823364, 1.0, 0.9189189076423645, 1.0, 0.9819819927215576, 0.9819819927215576, 0.9819819927215576, 0.9729729890823364, 0.9459459185600281, 1.0, 0.9909909963607788, 0.9729729890823364, 0.9459459185600281, 0.9819819927215576, 0.9909909963607788], [0.9279279112815857, 0.9099099040031433, 0.7747747898101807, 0.8288288116455078, 0.7387387156486511, 0.4864864945411682, 0.6486486196517944, 0.5045045018196106, 0.630630612373352, 0.6486486196517944, 0.7837837934494019, 0.5585585832595825, 0.7657657861709595, 0.6216216087341309, 0.7657657861709595, 0.522522509098053, 0.5675675868988037, 0.477477490901947, 0.4954954981803894, 0.6126126050949097, 0.5045045018196106, 0.6216216087341309, 0.5675675868988037, 0.5765765905380249, 0.5585585832595825, 0.6126126050949097, 0.5315315127372742, 0.45045045018196106, 0.4864864945411682, 0.44144144654273987, 0.342342346906662, 0.342342346906662, 0.44144144654273987, 0.5495495200157166, 0.46846845746040344, 0.6216216087341309, 0.45045045018196106, 0.4324324429035187, 0.46846845746040344, 0.37837839126586914, 0.4054054021835327, 0.3963963985443115, 0.477477490901947, 0.4054054021835327, 0.3333333432674408, 0.5135135054588318, 0.38738739490509033, 0.38738739490509033, 0.5585585832595825, 0.3243243098258972, 0.4234234094619751, 0.4864864945411682, 0.37837839126586914, 0.3243243098258972, 0.37837839126586914, 0.36936935782432556, 0.315315306186676, 0.3513513505458832, 0.3513513505458832, 0.36036035418510437, 0.38738739490509033, 0.28828829526901245, 0.45945945382118225, 0.4864864945411682, 0.3243243098258972, 0.3513513505458832, 0.3513513505458832, 0.3333333432674408, 0.342342346906662, 0.29729729890823364, 0.3963963985443115, 0.30630630254745483, 0.29729729890823364, 0.38738739490509033, 0.3243243098258972, 0.4144144058227539, 0.4054054021835327, 0.36936935782432556, 0.342342346906662, 0.3513513505458832, 0.36036035418510437, 0.3513513505458832, 0.4144144058227539, 0.36036035418510437, 0.315315306186676, 0.4054054021835327, 0.342342346906662, 0.3243243098258972, 0.27927929162979126, 0.3513513505458832, 0.38738739490509033, 0.4324324429035187, 0.315315306186676, 0.2702702581882477, 0.36036035418510437, 0.38738739490509033, 0.3333333432674408, 0.315315306186676, 0.3513513505458832, 0.3963963985443115, 0.3243243098258972, 0.36036035418510437], [0.09909909963607788, 0.045045044273138046, 0.06306306272745132, 0.036036036908626556, 0.018018018454313278, 0.027027027681469917, 0.1621621549129486, 0.036036036908626556, 0.171171173453331, 0.2522522509098053, 0.06306306272745132, 0.22522522509098053, 0.10810811072587967, 0.054054055362939835, 0.06306306272745132, 0.11711711436510086, 0.13513512909412384, 0.09009008854627609, 0.21621622145175934, 0.15315315127372742, 0.2432432472705841, 0.0810810774564743, 0.0810810774564743, 0.171171173453331, 0.28828829526901245, 0.21621622145175934, 0.18018017709255219, 0.2522522509098053, 0.09909909963607788, 0.027027027681469917, 0.09009008854627609, 0.22522522509098053, 0.11711711436510086, 0.12612612545490265, 0.10810811072587967, 0.22522522509098053, 0.09909909963607788, 0.09909909963607788, 0.18918919563293457, 0.15315315127372742, 0.22522522509098053, 0.22522522509098053, 0.21621622145175934, 0.12612612545490265, 0.15315315127372742, 0.09009008854627609, 0.07207207381725311, 0.15315315127372742, 0.0810810774564743, 0.21621622145175934, 0.12612612545490265, 0.22522522509098053, 0.09909909963607788, 0.28828829526901245, 0.0810810774564743, 0.09909909963607788, 0.20720720291137695, 0.11711711436510086, 0.21621622145175934, 0.11711711436510086, 0.19819819927215576, 0.12612612545490265, 0.09009008854627609, 0.22522522509098053, 0.054054055362939835, 0.30630630254745483, 0.21621622145175934, 0.171171173453331, 0.13513512909412384, 0.045045044273138046, 0.15315315127372742, 0.1621621549129486, 0.12612612545490265, 0.2702702581882477, 0.20720720291137695, 0.2702702581882477, 0.11711711436510086, 0.0810810774564743, 0.15315315127372742, 0.20720720291137695, 0.23423422873020172, 0.10810811072587967, 0.09009008854627609, 0.20720720291137695, 0.09009008854627609, 0.171171173453331, 0.09009008854627609, 0.12612612545490265, 0.09909909963607788, 0.0810810774564743, 0.23423422873020172, 0.23423422873020172, 0.09909909963607788, 0.15315315127372742, 0.2522522509098053, 0.18918919563293457, 0.1621621549129486, 0.2432432472705841, 0.15315315127372742, 0.15315315127372742, 0.045045044273138046, 0.18918919563293457], [0.9819819927215576, 0.9729729890823364, 0.9819819927215576, 0.9729729890823364, 0.9729729890823364, 0.9819819927215576, 0.9819819927215576, 0.9729729890823364, 0.9819819927215576, 0.9729729890823364, 0.9729729890823364, 0.9819819927215576, 0.9819819927215576, 0.9729729890823364, 0.9819819927215576, 0.9819819927215576, 0.9639639854431152, 0.9819819927215576, 0.9819819927215576, 0.9729729890823364, 0.954954981803894, 0.954954981803894, 0.9819819927215576, 0.9819819927215576, 0.9819819927215576, 0.9819819927215576, 0.9729729890823364, 0.9729729890823364, 0.9729729890823364, 0.9819819927215576, 0.9729729890823364, 0.9729729890823364, 0.9819819927215576, 0.9729729890823364, 0.9909909963607788, 0.9819819927215576, 0.9819819927215576, 0.9819819927215576, 0.9729729890823364, 0.9819819927215576, 0.9729729890823364, 0.9729729890823364, 0.9819819927215576, 0.9729729890823364, 0.9819819927215576, 0.9729729890823364, 0.9819819927215576, 0.9729729890823364, 0.9819819927215576, 0.9819819927215576, 0.9729729890823364, 0.9639639854431152, 0.9639639854431152, 0.9729729890823364, 0.9729729890823364, 0.954954981803894, 0.9639639854431152, 0.9729729890823364, 0.9819819927215576, 0.9819819927215576, 0.9819819927215576, 0.9819819927215576, 0.9819819927215576, 0.9729729890823364, 0.9639639854431152, 0.9819819927215576, 0.9819819927215576, 0.9459459185600281, 0.9819819927215576, 0.9729729890823364, 0.9819819927215576, 0.9729729890823364, 0.9639639854431152, 0.9819819927215576, 0.9819819927215576, 0.9819819927215576, 0.9819819927215576, 0.9729729890823364, 0.9819819927215576, 0.9729729890823364, 0.9729729890823364, 0.9819819927215576, 0.9819819927215576, 0.9459459185600281, 0.9639639854431152, 0.954954981803894, 0.9729729890823364, 0.9819819927215576, 0.9729729890823364, 0.9459459185600281, 0.9819819927215576, 0.9729729890823364, 0.9819819927215576, 0.954954981803894, 0.9459459185600281, 0.9729729890823364, 0.9819819927215576, 0.9729729890823364, 0.9819819927215576, 0.954954981803894, 0.9819819927215576, 0.9819819927215576], [0.9909909963607788, 0.9909909963607788, 0.9909909963607788, 0.9729729890823364, 0.9819819927215576, 0.9819819927215576, 0.9909909963607788, 0.9909909963607788, 0.9909909963607788, 0.9819819927215576, 0.9819819927215576, 0.9819819927215576, 0.9819819927215576, 0.9639639854431152, 0.9819819927215576, 0.9909909963607788, 0.9909909963607788, 0.9819819927215576, 0.9819819927215576, 0.9909909963607788, 0.9909909963607788, 0.9909909963607788, 0.9819819927215576, 0.9909909963607788, 0.9909909963607788, 0.9819819927215576, 0.9909909963607788, 0.9909909963607788, 0.9639639854431152, 0.9819819927215576, 0.9819819927215576, 0.9819819927215576, 0.9819819927215576, 0.9729729890823364, 0.9639639854431152, 0.9819819927215576, 0.9819819927215576, 0.9819819927215576, 0.9729729890823364, 0.9729729890823364, 0.9729729890823364, 0.9819819927215576, 0.9909909963607788, 0.9819819927215576, 0.954954981803894, 0.9729729890823364, 0.9819819927215576, 0.9819819927215576, 0.9819819927215576, 0.9819819927215576, 0.954954981803894, 0.9729729890823364, 0.9819819927215576, 0.9819819927215576, 0.9909909963607788, 0.9819819927215576, 0.9639639854431152, 0.9819819927215576, 0.9909909963607788, 0.9729729890823364, 0.954954981803894, 0.9819819927215576, 0.9819819927215576, 0.9819819927215576, 0.9909909963607788, 0.9819819927215576, 0.9909909963607788, 0.954954981803894, 0.9819819927215576, 0.9639639854431152, 0.9819819927215576, 0.9819819927215576, 0.9639639854431152, 0.9819819927215576, 0.9729729890823364, 0.9909909963607788, 0.9909909963607788, 0.9909909963607788, 0.954954981803894, 0.9639639854431152, 0.9819819927215576, 0.9729729890823364, 0.9909909963607788, 0.9909909963607788, 0.9729729890823364, 0.9909909963607788, 0.9729729890823364, 0.9729729890823364, 0.9819819927215576, 0.9819819927215576, 0.9729729890823364, 0.9729729890823364, 0.9729729890823364, 0.9819819927215576, 0.9819819927215576, 0.9909909963607788, 0.9729729890823364, 0.9909909963607788, 1.0, 0.9729729890823364, 0.9909909963607788, 0.9729729890823364]]
prob_list = [[0.999798059463501, 0.9997609257698059, 0.9997243881225586, 0.9996716380119324, 0.9996636509895325, 0.9996532201766968, 0.9995818138122559, 0.9995136857032776, 0.9994621276855469, 0.9994298815727234, 0.9994193911552429, 0.9992927312850952, 0.9991897344589233, 0.999162495136261, 0.9990848898887634, 0.9990408420562744, 0.9990352392196655, 0.9990233182907104, 0.9989814758300781, 0.9989410042762756, 0.9989268779754639, 0.9989258050918579, 0.9989162683486938, 0.9988672733306885, 0.9988577961921692, 0.9988387227058411, 0.9987951517105103, 0.9987403750419617, 0.9985683560371399, 0.9985527992248535, 0.9985092878341675, 0.9984806180000305, 0.9984214305877686, 0.9983848929405212, 0.9983698725700378, 0.9983627200126648, 0.9983537197113037, 0.9983531832695007, 0.9982770681381226, 0.9982223510742188, 0.9982121586799622, 0.9981723427772522, 0.9981693625450134, 0.9981638789176941, 0.9981609582901001, 0.9981470108032227, 0.998144268989563, 0.9981099963188171, 0.9980669617652893, 0.9980605840682983, 0.9980602264404297, 0.9980195760726929, 0.9980186820030212, 0.9980130791664124, 0.9979899525642395, 0.997988760471344, 0.997943103313446, 0.9978846311569214, 0.9978086352348328, 0.9978045225143433, 0.9977808594703674, 0.9977502822875977, 0.9977333545684814, 0.9976773858070374, 0.9976698756217957, 0.9976546168327332, 0.9976219534873962, 0.9976049065589905, 0.9975472688674927, 0.9975295066833496, 0.9974830150604248, 0.9974546432495117, 0.9974456429481506, 0.9974451065063477, 0.997439980506897, 0.9974192976951599, 0.997401237487793, 0.9973815083503723, 0.9973633885383606, 0.9973442554473877, 0.9972391128540039, 0.9971839189529419, 0.9971553087234497, 0.9971469640731812, 0.9971455931663513, 0.9971393346786499, 0.9970966577529907, 0.9970418214797974, 0.9970322251319885, 0.9969974756240845, 0.9969155788421631, 0.9968834519386292, 0.9968796372413635, 0.9968288540840149, 0.996789276599884, 0.9966727495193481, 0.9966233968734741, 0.9966047406196594, 0.996592104434967, 0.9965866804122925, 0.9965119957923889, 0.9965052008628845], [0.9817234873771667, 0.9785272479057312, 0.9776372313499451, 0.9764548540115356, 0.9757897257804871, 0.9756594896316528, 0.9752596020698547, 0.9750471115112305, 0.9744119048118591, 0.9741156697273254, 0.9739365577697754, 0.9731894731521606, 0.9729788303375244, 0.9728564023971558, 0.9725017547607422, 0.9723671674728394, 0.9723227620124817, 0.9722226858139038, 0.9716639518737793, 0.9716222286224365, 0.9715116024017334, 0.9712647199630737, 0.9710476398468018, 0.9707461595535278, 0.9707195162773132, 0.9705784916877747, 0.9702393412590027, 0.9698932766914368, 0.9696681499481201, 0.969649076461792, 0.9693778157234192, 0.9691499471664429, 0.9688546657562256, 0.9686205387115479, 0.9684120416641235, 0.9681845903396606, 0.9675257802009583, 0.9672830700874329, 0.967068612575531, 0.9667770266532898, 0.9667015671730042, 0.9666045308113098, 0.9664918780326843, 0.9664787650108337, 0.966399073600769, 0.9663869142532349, 0.9660936594009399, 0.9659124612808228, 0.965833306312561, 0.9658108353614807, 0.9657067656517029, 0.9656558036804199, 0.9652805328369141, 0.9652679562568665, 0.9649048447608948, 0.9648956060409546, 0.964782178401947, 0.9644887447357178, 0.9644779562950134, 0.9643340110778809, 0.964260458946228, 0.9642398357391357, 0.9641810655593872, 0.9641528129577637, 0.9641153216362, 0.9637807011604309, 0.9636768102645874, 0.9636565446853638, 0.9636304378509521, 0.9635757803916931, 0.9635730981826782, 0.963456928730011, 0.9633830189704895, 0.9632798433303833, 0.9632640480995178, 0.9632226824760437, 0.9632042646408081, 0.9631358981132507, 0.9629944562911987, 0.9629855155944824, 0.962975263595581, 0.9629319310188293, 0.9628963470458984, 0.9627582430839539, 0.9624484181404114, 0.9623445272445679, 0.9622284173965454, 0.962142288684845, 0.9617661237716675, 0.9616733193397522, 0.961666464805603, 0.9614667892456055, 0.9614167809486389, 0.961249828338623, 0.9612241387367249, 0.9611873030662537, 0.9611416459083557, 0.9611138701438904, 0.9610026478767395, 0.9609654545783997, 0.9608424305915833, 0.9607900977134705], [0.9796269536018372, 0.9779704213142395, 0.9765051007270813, 0.9752669930458069, 0.9749463200569153, 0.9747986793518066, 0.9729233980178833, 0.9728026986122131, 0.9721291065216064, 0.9706289768218994, 0.9704241752624512, 0.970364511013031, 0.9701799154281616, 0.9699467420578003, 0.9696168303489685, 0.969573974609375, 0.969325602054596, 0.9691293239593506, 0.9690230488777161, 0.9689098000526428, 0.968425989151001, 0.9683214426040649, 0.9681045413017273, 0.9680607914924622, 0.9678324460983276, 0.9675041437149048, 0.9674469828605652, 0.9670031070709229, 0.9664599299430847, 0.9653291702270508, 0.9652350544929504, 0.9650313258171082, 0.9645279049873352, 0.9645137786865234, 0.9644918441772461, 0.9644035696983337, 0.9642328023910522, 0.9641045331954956, 0.9640594720840454, 0.9640561938285828, 0.9640471935272217, 0.9640272855758667, 0.9639513492584229, 0.9638761878013611, 0.9638166427612305, 0.9636589288711548, 0.9634768962860107, 0.9632647037506104, 0.9631083607673645, 0.9630898833274841, 0.9630668759346008, 0.9629735350608826, 0.9627017974853516, 0.9626386761665344, 0.9625383615493774, 0.9625263214111328, 0.9624972939491272, 0.9623945951461792, 0.9622145891189575, 0.9619469046592712, 0.9619041085243225, 0.961887776851654, 0.9617716073989868, 0.9616421461105347, 0.9616131782531738, 0.9616064429283142, 0.9613900780677795, 0.9612351059913635, 0.9611502885818481, 0.961097240447998, 0.960878312587738, 0.9608434438705444, 0.9605333805084229, 0.9604579210281372, 0.9604002833366394, 0.9603492617607117, 0.960338830947876, 0.9601511359214783, 0.9600471258163452, 0.9598530530929565, 0.9598486423492432, 0.9597331881523132, 0.9597138166427612, 0.9595560431480408, 0.9595353007316589, 0.9594456553459167, 0.9594177603721619, 0.959350049495697, 0.9593121409416199, 0.9591972827911377, 0.9590803980827332, 0.9590479731559753, 0.9590256810188293, 0.9589102864265442, 0.9587776064872742, 0.9587411880493164, 0.9586794376373291, 0.9586426019668579, 0.9585512280464172, 0.9585363864898682, 0.9585250616073608, 0.9584994316101074], [0.982528567314148, 0.9816603660583496, 0.9746264815330505, 0.9739543199539185, 0.9737696051597595, 0.9732950329780579, 0.9727626442909241, 0.9703720808029175, 0.9697786569595337, 0.9694781303405762, 0.9688504338264465, 0.9683973789215088, 0.9683190584182739, 0.9676313400268555, 0.9675401449203491, 0.967525064945221, 0.9673555493354797, 0.966675341129303, 0.9666330218315125, 0.9666058421134949, 0.966178834438324, 0.9659404158592224, 0.9658769369125366, 0.9657328724861145, 0.9653806686401367, 0.9651960730552673, 0.9650385975837708, 0.9646810293197632, 0.9646472930908203, 0.9644439220428467, 0.9641814231872559, 0.9634606242179871, 0.9634435176849365, 0.9633919596672058, 0.9632967710494995, 0.9631830453872681, 0.9630329012870789, 0.9625131487846375, 0.9618176817893982, 0.9616665244102478, 0.961598813533783, 0.9613313674926758, 0.9610714316368103, 0.9609618186950684, 0.9609208703041077, 0.9607269763946533, 0.9607110023498535, 0.9606322646141052, 0.9604658484458923, 0.960253894329071, 0.9596662521362305, 0.9594911932945251, 0.9594327211380005, 0.9592575430870056, 0.9591984152793884, 0.9591659307479858, 0.9591086506843567, 0.9589337706565857, 0.958751916885376, 0.9587112069129944, 0.9586741328239441, 0.9585857391357422, 0.9583753943443298, 0.9582207202911377, 0.9579774141311646, 0.9579231142997742, 0.9578258991241455, 0.9577410817146301, 0.9575625658035278, 0.9573559165000916, 0.9573303461074829, 0.9568030834197998, 0.956739068031311, 0.9565786123275757, 0.956463098526001, 0.9561796188354492, 0.9561334848403931, 0.9559201002120972, 0.955864429473877, 0.9557777643203735, 0.9556957483291626, 0.9556658267974854, 0.9555184841156006, 0.9552361369132996, 0.9550356268882751, 0.9550142288208008, 0.9548836946487427, 0.9546293020248413, 0.9545192122459412, 0.9545065760612488, 0.9544780254364014, 0.9543478488922119, 0.95427405834198, 0.9542123675346375, 0.954182505607605, 0.9540306925773621, 0.9538965821266174, 0.9538132548332214, 0.9537861943244934, 0.9536778926849365, 0.953654408454895, 0.953607439994812], [0.9962618947029114, 0.9962359666824341, 0.9940988421440125, 0.9940952658653259, 0.9939370155334473, 0.9933323860168457, 0.993274450302124, 0.9929867386817932, 0.9929836392402649, 0.9926305413246155, 0.9924961924552917, 0.9921882152557373, 0.9918815493583679, 0.9918420314788818, 0.9918159246444702, 0.9917689561843872, 0.9917600750923157, 0.991748571395874, 0.9916809797286987, 0.9916248321533203, 0.9915640950202942, 0.9915551543235779, 0.991539716720581, 0.9915123581886292, 0.9915103912353516, 0.9914835095405579, 0.9914689064025879, 0.9914180040359497, 0.9914093613624573, 0.991401195526123, 0.9913724660873413, 0.9913482069969177, 0.9913210272789001, 0.9912845492362976, 0.9912844896316528, 0.9912833571434021, 0.991249680519104, 0.9912381768226624, 0.9911937713623047, 0.9911912083625793, 0.9911167621612549, 0.9910718202590942, 0.9910430312156677, 0.9910363554954529, 0.9910348653793335, 0.9910162687301636, 0.9909998774528503, 0.990997314453125, 0.9909651279449463, 0.9909480810165405, 0.9909345507621765, 0.9909231662750244, 0.9909182786941528, 0.9909132122993469, 0.9908819794654846, 0.9908693432807922, 0.9908515810966492, 0.9908245205879211, 0.9908155798912048, 0.9908043146133423, 0.9907421469688416, 0.9907311797142029, 0.9907229542732239, 0.9907078146934509, 0.9906930923461914, 0.9906731843948364, 0.9906308054924011, 0.9906076788902283, 0.9905974864959717, 0.9904988408088684, 0.9904861450195312, 0.9904653429985046, 0.9904295206069946, 0.9904133081436157, 0.9903989434242249, 0.9903278350830078, 0.9902737140655518, 0.9902559518814087, 0.9902554750442505, 0.9902251362800598, 0.9902238845825195, 0.990131139755249, 0.9900739789009094, 0.9900566935539246, 0.9900454878807068, 0.9900386333465576, 0.9899905920028687, 0.9899823069572449, 0.9899600744247437, 0.9899414777755737, 0.9899335503578186, 0.989916980266571, 0.9899077415466309, 0.989899218082428, 0.9898672699928284, 0.9898489117622375, 0.9898424744606018, 0.9897879958152771, 0.9897738695144653, 0.9897618293762207, 0.9897324442863464, 0.9896923303604126], [0.9450442790985107, 0.8900070786476135, 0.8794874548912048, 0.8765034079551697, 0.8563343286514282, 0.8449638485908508, 0.8381963968276978, 0.8345248699188232, 0.8335363864898682, 0.8257083296775818, 0.8199054598808289, 0.8173989057540894, 0.8154930472373962, 0.8099790215492249, 0.8079418540000916, 0.8077139854431152, 0.8076819181442261, 0.8033620715141296, 0.7989152669906616, 0.7922243475914001, 0.7913991808891296, 0.7883793711662292, 0.7875060439109802, 0.7834787368774414, 0.7808229327201843, 0.7797859907150269, 0.7778797149658203, 0.776918351650238, 0.7697858810424805, 0.7682508826255798, 0.7681737542152405, 0.7681737542152405, 0.7671624422073364, 0.764695942401886, 0.763361394405365, 0.7623496055603027, 0.7622898817062378, 0.7612999081611633, 0.7610208988189697, 0.7594681978225708, 0.7554068565368652, 0.755370020866394, 0.7548799514770508, 0.751437783241272, 0.748779296875, 0.7487309575080872, 0.7459191679954529, 0.745694637298584, 0.7422711253166199, 0.7418997287750244, 0.7416740655899048, 0.7397549152374268, 0.7369393110275269, 0.7366416454315186, 0.7343281507492065, 0.7333763241767883, 0.7332356572151184, 0.7331897616386414, 0.7327280640602112, 0.7314372658729553, 0.7310234308242798, 0.7309238910675049, 0.7303101420402527, 0.7302184104919434, 0.7293860912322998, 0.7264994978904724, 0.7264994978904724, 0.7257937788963318, 0.7249041199684143, 0.7248242497444153, 0.724549412727356, 0.7241301536560059, 0.7236127257347107, 0.7235265374183655, 0.722653865814209, 0.7225748896598816, 0.7213932871818542, 0.7210787534713745, 0.720945417881012, 0.7199737429618835, 0.7193097472190857, 0.7188809514045715, 0.7176252007484436, 0.7174767851829529, 0.717387855052948, 0.7173866033554077, 0.7170433402061462, 0.7163699865341187, 0.7156012654304504, 0.7155174612998962, 0.7150073647499084, 0.7148811221122742, 0.7144075632095337, 0.714149534702301, 0.713483989238739, 0.7128716707229614, 0.7126681208610535, 0.7116051316261292, 0.7115867137908936, 0.7115576267242432, 0.711024820804596, 0.7103127241134644], [0.9914625287055969, 0.9913865923881531, 0.9910546541213989, 0.9890826940536499, 0.9881108999252319, 0.9875422716140747, 0.9872554540634155, 0.9870415329933167, 0.9870398640632629, 0.9868173599243164, 0.9865642786026001, 0.9865396618843079, 0.9865391254425049, 0.9863222241401672, 0.9862241148948669, 0.9859433770179749, 0.9858052134513855, 0.9857724905014038, 0.9857711791992188, 0.9857069253921509, 0.9855886697769165, 0.9855689406394958, 0.9855689406394958, 0.9855484366416931, 0.9854944348335266, 0.9854599833488464, 0.9854305982589722, 0.985424816608429, 0.9853605628013611, 0.9850888848304749, 0.9850178360939026, 0.9849340915679932, 0.9848824739456177, 0.9848802089691162, 0.9848375916481018, 0.9848281145095825, 0.9847772121429443, 0.9847543835639954, 0.9847471117973328, 0.98466956615448, 0.9845507740974426, 0.9845468401908875, 0.9845350384712219, 0.9845187664031982, 0.9844965934753418, 0.9844658374786377, 0.9843281507492065, 0.9843224883079529, 0.9843153953552246, 0.9842690229415894, 0.9842603802680969, 0.9842472672462463, 0.9842242002487183, 0.9841914772987366, 0.9841830730438232, 0.9839462637901306, 0.9839314818382263, 0.9839127063751221, 0.9839079976081848, 0.9837791323661804, 0.9837789535522461, 0.9837646484375, 0.9837576150894165, 0.9837507009506226, 0.9837449789047241, 0.9836942553520203, 0.9836573004722595, 0.9836475253105164, 0.9836066961288452, 0.9836016297340393, 0.9835909605026245, 0.9835889935493469, 0.9835307002067566, 0.9834864735603333, 0.9834781289100647, 0.9834458231925964, 0.9834340214729309, 0.9833908677101135, 0.9833354949951172, 0.983235239982605, 0.9832062125205994, 0.9832049012184143, 0.9831846952438354, 0.9831750392913818, 0.9831539988517761, 0.9831278324127197, 0.983002245426178, 0.9829297661781311, 0.9828432202339172, 0.9827500581741333, 0.9827188849449158, 0.9827188849449158, 0.9826750755310059, 0.9826740026473999, 0.9826594591140747, 0.9826410412788391, 0.9826366901397705, 0.9826234579086304, 0.9825868606567383, 0.9825079441070557, 0.982452392578125, 0.9824007153511047], [0.9933387637138367, 0.9932316541671753, 0.9923577308654785, 0.9922049641609192, 0.991722047328949, 0.99170982837677, 0.9915115833282471, 0.9913317561149597, 0.9910642504692078, 0.9908437728881836, 0.9905516505241394, 0.9904618859291077, 0.9900728464126587, 0.9900256991386414, 0.9898050427436829, 0.9897936582565308, 0.9896578192710876, 0.9896389842033386, 0.9896085858345032, 0.9894849061965942, 0.9893649220466614, 0.9893649220466614, 0.9893522262573242, 0.9893458485603333, 0.9892410635948181, 0.9891753196716309, 0.9891207814216614, 0.9891207814216614, 0.9890836477279663, 0.9890491366386414, 0.9890003204345703, 0.9888866543769836, 0.9888659119606018, 0.9888610243797302, 0.9888567328453064, 0.9887605905532837, 0.9886168241500854, 0.9885454773902893, 0.9885188341140747, 0.9884737133979797, 0.9884617328643799, 0.9884608387947083, 0.9883822798728943, 0.9883486032485962, 0.9880001544952393, 0.9879890084266663, 0.9879744648933411, 0.9878502488136292, 0.987819492816925, 0.9876888394355774, 0.9876157641410828, 0.9875303506851196, 0.9875005483627319, 0.9874879121780396, 0.9873949885368347, 0.9873784780502319, 0.9873462319374084, 0.9873121380805969, 0.9872682690620422, 0.9871925115585327, 0.9871401190757751, 0.9871379137039185, 0.9871112704277039, 0.9870783090591431, 0.9870628118515015, 0.9868687987327576, 0.9868627190589905, 0.9868147969245911, 0.9866924285888672, 0.9865671396255493, 0.9865587949752808, 0.9864859580993652, 0.9864835739135742, 0.9864541888237, 0.9864044189453125, 0.9864006638526917, 0.9862088561058044, 0.9861986041069031, 0.9860697984695435, 0.9860434532165527, 0.9860214591026306, 0.9860170483589172, 0.9859792590141296, 0.9859781265258789, 0.9859337210655212, 0.9859290719032288, 0.9859133362770081, 0.9858909845352173, 0.9858599305152893, 0.9858450889587402, 0.9857838749885559, 0.9857778549194336, 0.9857764840126038, 0.985692024230957, 0.985539972782135, 0.9855263233184814, 0.9854715466499329, 0.9853859543800354, 0.9853073954582214, 0.9852293729782104, 0.9852196574211121, 0.9851891994476318], [0.9979798793792725, 0.9979432225227356, 0.9975451827049255, 0.9973692893981934, 0.9971492290496826, 0.9971070289611816, 0.9966923594474792, 0.9965962171554565, 0.9964686036109924, 0.9963430166244507, 0.9963186383247375, 0.9957948327064514, 0.9957292675971985, 0.9956396818161011, 0.9956262111663818, 0.9956238865852356, 0.9955843687057495, 0.9954198598861694, 0.9953758120536804, 0.9953103065490723, 0.9950469136238098, 0.9950356483459473, 0.9949610233306885, 0.9949295520782471, 0.9946702718734741, 0.9946547746658325, 0.9944997429847717, 0.9942606091499329, 0.9940365552902222, 0.993988037109375, 0.993751585483551, 0.993404746055603, 0.9933041930198669, 0.993100643157959, 0.9930046200752258, 0.9929131269454956, 0.9925493001937866, 0.9925011396408081, 0.9924747943878174, 0.9924297332763672, 0.9924280643463135, 0.9920867085456848, 0.9920803904533386, 0.9919462203979492, 0.991753876209259, 0.9917145371437073, 0.9916725754737854, 0.9916725754737854, 0.991528332233429, 0.9915065169334412, 0.9914698004722595, 0.9914343357086182, 0.9913005828857422, 0.99126136302948, 0.9912552237510681, 0.9910513162612915, 0.9909546375274658, 0.9908639788627625, 0.9908109903335571, 0.9906459450721741, 0.9906203746795654, 0.9905369877815247, 0.9904628396034241, 0.9904350638389587, 0.9903453588485718, 0.9902966022491455, 0.9900204539299011, 0.9898535013198853, 0.989772379398346, 0.9897451400756836, 0.9896813035011292, 0.989677906036377, 0.9896506667137146, 0.9896453022956848, 0.9896261096000671, 0.9895188808441162, 0.9895167946815491, 0.9895046949386597, 0.9894501566886902, 0.989448070526123, 0.9893593788146973, 0.9892783164978027, 0.9891482591629028, 0.9890274405479431, 0.9889963269233704, 0.9889777898788452, 0.9888536334037781, 0.9886375665664673, 0.9884473085403442, 0.9884059429168701, 0.9879454374313354, 0.9878860116004944, 0.9877198338508606, 0.9875537753105164, 0.9874809384346008, 0.9874266386032104, 0.9873436093330383, 0.9873360395431519, 0.9873173236846924, 0.9872413277626038, 0.9872055053710938, 0.987018346786499]]

In [24]:
acc_df = pd.DataFrame(acc)
acc_df = pd.DataFrame({'mean': acc_df.mean(axis=1), 'max': acc_df.max(axis=1), 'min': acc_df.min(axis=1)})

In [25]:
acc_df

Unnamed: 0,mean,max,min
0,0.973712,1.0,0.0
1,0.946019,0.977778,0.769231
2,0.963192,0.981982,0.877778
3,0.959116,0.981982,0.577778
4,0.97262,1.0,0.891892
5,0.45098,0.927928,0.27027
6,0.147677,0.306306,0.018018
7,0.974739,0.990991,0.945946
8,0.98048,1.0,0.954955


In [26]:
fig = px.line(test_metric, x="Loop", y="value", color='metric', markers=True)
fig.show()

In [None]:
file_name_lst = ["answers", "emails", "newsgroups", "reviews", "weblogs"]