In [1]:
import pandas as pd
import numpy as np
import os
import time
import random
import collections
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
from datasets import Dataset
from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
from transformers import *
import re
import copy

In [2]:
CFG = {
    'fold_num': 5,
    'seed': 42,
#     'model': 'deepset/xlm-roberta-large-squad2',
    'model': 'squad2/infoxlm-squad2-512',
#     'model': 'squad2/muril-large-squad2',
#     'model': 'google/muril-large-cased',
    'max_length': 512, 
    'doc_stride': 128,
    'epochs': 5, 
    'train_bs': 4, 
    'valid_bs': 8,
    'lr': 5e-6, 
    'weight_decay': 1e-6
}

In [3]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(CFG['seed'])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
train = pd.read_csv('train.csv')
test = pd.read_csv('test.csv')

ext1 = pd.read_csv('mlqa.csv')
ext2 = pd.read_csv('xquad.csv')

In [5]:
train.language.value_counts()

hindi    746
tamil    368
Name: language, dtype: int64

In [8]:
def convert_answers(r):
    start = r[0]
    text = r[1]
    language = ['hindi', 'tamil'].index(r[2])
    return {
        'answer_start': [start],
        'text': [text],
        'language': [language]
    }


train['answers'] = train[['answer_start', 'answer_text', 'language']].apply(convert_answers, axis=1)
ext1['answers'] = ext1[['answer_start', 'answer_text', 'language']].apply(convert_answers, axis=1)
ext2['answers'] = ext2[['answer_start', 'answer_text', 'language']].apply(convert_answers, axis=1)

train

Unnamed: 0,id,context,question,answer_text,answer_start,language,answers
0,903deec17,ஒரு சாதாரண வளர்ந்த மனிதனுடைய எலும்புக்கூடு பின...,மனித உடலில் எத்தனை எலும்புகள் உள்ளன?,206,53,tamil,"{'answer_start': [53], 'text': ['206'], 'langu..."
1,d9841668c,காளிதாசன் (தேவநாகரி: कालिदास) சமஸ்கிருத இலக்கி...,காளிதாசன் எங்கு பிறந்தார்?,காசுமீரில்,2358,tamil,"{'answer_start': [2358], 'text': ['காசுமீரில்'..."
2,29d154b56,சர் அலெக்ஸாண்டர் ஃபிளெமிங் (Sir Alexander Flem...,பென்சிலின் கண்டுபிடித்தவர் யார்?,சர் அலெக்ஸாண்டர் ஃபிளெமிங்,0,tamil,"{'answer_start': [0], 'text': ['சர் அலெக்ஸாண்ட..."
3,41660850a,"குழந்தையின் அழுகையை நிறுத்தவும், தூங்க வைக்கவ...",தமிழ்நாட்டில் குழந்தைகளை தூங்க வைக்க பாடும் பா...,தாலாட்டு,68,tamil,"{'answer_start': [68], 'text': ['தாலாட்டு'], '..."
4,b29c82c22,சூரியக் குடும்பம் \nசூரியக் குடும்பம் (Solar S...,பூமியின் அருகில் உள்ள விண்மீன் எது?,சூரியனும்,585,tamil,"{'answer_start': [585], 'text': ['சூரியனும்'],..."
...,...,...,...,...,...,...,...
1109,26f356026,स्वामी निगमानन्द परमहंस (18 अगस्त 1880 - 29 नव...,स्वामी निगमानन्द परमहंस के तन्त्र गुरु कौन थे?,बामाक्षेपा,3619,hindi,"{'answer_start': [3619], 'text': ['बामाक्षेपा'..."
1110,31179f1bb,भरत मुनि ने नाट्यशास्त्र नामक प्रसिद्ध ग्रन्थ ...,नित्यशास्त्र किसने लिखा है?,भरत मुनि,0,hindi,"{'answer_start': [0], 'text': ['भरत मुनि'], 'l..."
1111,0d35dc007,अग्नि पंचम (अग्नि-५) भारत की अन्तरमहाद्वीपीय ब...,अग्नि पंचम(५) मिसाइल की लम्बाई कितने मीटर है?,17,155,hindi,"{'answer_start': [155], 'text': ['17'], 'langu..."
1112,7f997884d,"जलाल उद्दीन मोहम्मद अकबर () (१५ अक्तूबर, १५४२-...",मुगल सम्राट अकबर की मृत्यु किस वर्ष में हुई थी?,"२७ अक्तूबर, १६०५",46,hindi,"{'answer_start': [46], 'text': ['२७ अक्तूबर, १..."


In [9]:
def prepare_train_features(examples):
    # Some of the questions have lots of whitespace on the left, which is not useful and will make the
    # truncation of the context fail (the tokenized question will take a lots of space). So we remove that
    # left whitespace
    examples["question"] = [q.lstrip() for q in examples["question"]]

    # Tokenize our examples with truncation and padding, but keep the overflows using a stride. This results
    # in one example possible giving several features when a context is long, each of those features having a
    # context that overlaps a bit the context of the previous feature.
    tokenized_examples = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",
        max_length=CFG['max_length'],
        stride=CFG['doc_stride'],
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

#     print(tokenized_examples)
    # Since one example might give us several features if it has a long context, we need a map from a feature to
    # its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # The offset mappings will give us a map from token to character position in the original context. This will
    # help us compute the start_positions and end_positions.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    # Let's label those examples!
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []
    tokenized_examples["has_answer"] = []
    tokenized_examples["language"] = []
    
    for i, offsets in enumerate(offset_mapping):
        # We will label impossible answers with the index of the CLS token.
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)

        # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        # If no answers are given, set the cls_index as answer.
        if len(answers["answer_start"]) == 0:
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Start/end character index of the answer in the text.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])
            
            language = answers["language"][0]
            tokenized_examples["language"].append(language)
            
            # Start token index of the current span in the text.
            token_start_index = 0
            while sequence_ids[token_start_index] != 1:
                token_start_index += 1

            # End token index of the current span in the text.
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != 1:
                token_end_index -= 1

            # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
#                 if random.random() < 0.5:
#                     tokenized_examples["start_positions"].append(cls_index)
#                     tokenized_examples["end_positions"].append(cls_index)

#                     tokenized_examples["has_answer"].append(0)
#                 else:
#                     tokenized_examples["start_positions"].append(-100)
#                     tokenized_examples["end_positions"].append(-100)

#                     tokenized_examples["has_answer"].append(0)
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)

                tokenized_examples["has_answer"].append(0)
            else:
                # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                # Note: we could go after the last offset if the answer is the last word (edge case).
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 1)
                
                tokenized_examples["has_answer"].append(1)


    return tokenized_examples

In [10]:
def prepare_validation_features(examples):
    # Some of the questions have lots of whitespace on the left, which is not useful and will make the
    # truncation of the context fail (the tokenized question will take a lots of space). So we remove that
    # left whitespace
    examples["question"] = [q.lstrip() for q in examples["question"]]

    # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
    # in one example possible giving several features when a context is long, each of those features having a
    # context that overlaps a bit the context of the previous feature.
    tokenized_examples = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",
        max_length=CFG['max_length'],
        stride=CFG['doc_stride'],
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Since one example might give us several features if it has a long context, we need a map from a feature to
    # its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

    # We keep the example_id that gave us this feature and we will store the offset mappings.
    tokenized_examples["example_id"] = []

    for i in range(len(tokenized_examples["input_ids"])):
        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)
        context_index = 1

        # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        tokenized_examples["example_id"].append(examples["id"][sample_index])

        # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
        # position is part of the context or not.
        tokenized_examples["offset_mapping"][i] = [
            (o if sequence_ids[k] == context_index else None)
            for k, o in enumerate(tokenized_examples["offset_mapping"][i])
        ]

    return tokenized_examples

In [11]:
def postprocess_qa_predictions(examples, features, raw_predictions, n_best_size=20, max_answer_length=30, w=0):
    all_start_logits, all_end_logits = raw_predictions
    # Build a map example to its corresponding features.
    example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features):
        features_per_example[example_id_to_index[feature["example_id"]]].append(i)

    # The dictionaries we have to fill.
    predictions = collections.OrderedDict()

    # Logging.
#     print(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")

    # Let's loop over all the examples!
    for example_index, example in enumerate(tqdm(examples)):
        # Those are the indices of the features associated to the current example.
        feature_indices = features_per_example[example_index]

        min_null_score = None # Only used if squad_v2 is True.
        valid_answers = []
        
        context = example["context"]
        # Looping through all the features associated to the current example.
        for i, feature_index in enumerate(feature_indices):
            # We grab the predictions of the model for this feature.
            start_logits = all_start_logits[feature_index]
            end_logits = all_end_logits[feature_index]
            # This is what will allow us to map some the positions in our logits to span of texts in the original
            # context.
            offset_mapping = features[feature_index]["offset_mapping"]

            # Update minimum null prediction.
            cls_index = features[feature_index]["input_ids"].index(tokenizer.cls_token_id)
            feature_null_score = start_logits[cls_index] + end_logits[cls_index]
            if min_null_score is None or min_null_score < feature_null_score:
                min_null_score = feature_null_score

            # Go through all possibilities for the `n_best_size` greater start and end logits.
            start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
            end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
                    # to part of the input_ids that are not in the context.
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or offset_mapping[end_index] is None
                    ):
                        continue
                    # Don't consider answers with a length that is either < 0 or > max_answer_length.
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                        continue

                    start_char = offset_mapping[start_index][0]
                    end_char = offset_mapping[end_index][1]
                    valid_answers.append(
                        {
                            "score": start_logits[start_index] + end_logits[end_index], #+ scores*w, # - i/len(feature_indices)*w,
                            "text": context[start_char: end_char]
                        }
                    )
        
        if len(valid_answers) > 0:
            best_answer = sorted(valid_answers, key=lambda x: x["score"], reverse=True)[0]
        else:
            # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
            # failure.
            best_answer = {"text": "", "score": 0.0}
        
        # Let's pick our final answer: the best one or the null answer (only for squad_v2)
        predictions[example["id"]] = best_answer["text"]

    return predictions

In [12]:
def cal_jaccard(row): 
    str1 = row[0]
    str2 = row[1]
    a = set(str1.lower().split()) 
    b = set(str2.lower().split())
    c = a.intersection(b)
    return float(len(c)) / (len(a) + len(b) - len(c))


class AverageMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        
        
def train_model(model, train_loader, gradient_accumulation_steps=4):
    model.train()
    losses = AverageMeter()
    accs = AverageMeter()
    
    optimizer.zero_grad()
    
    tk = tqdm(train_loader, total=len(train_loader), position=0, leave=True)
    for step, batch in enumerate(tk):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)
        labels = batch['has_answer'].to(device)
        language = batch['language'].to(device)
        
        with autocast():
            output = model(input_ids, attention_mask)
            start_logits, end_logits = output['start_logits'], output['end_logits']
            
            start_loss = criterion(start_logits, start_positions)
            end_loss = criterion(end_logits, end_positions)
            loss = (start_loss + end_loss) / 2
                
            loss = loss / gradient_accumulation_steps
            
        scaler.scale(loss).backward()
        
        if step % gradient_accumulation_steps == 0 or step == len(train_loader) - 1:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad() 
        scheduler.step()
            
        lr = optimizer.param_groups[-1]['lr']

        losses.update(loss.item()*gradient_accumulation_steps, input_ids.size(0))
        tk.set_postfix(loss=losses.avg, lr=lr)
        
    return losses.avg


def test_model(model, val_loader):
    model.eval()
    
    losses = AverageMeter()
    
    all_start_logits, all_end_logits = [], []
    
    with torch.no_grad():
        tk = tqdm(val_loader, total=len(val_loader), position=0, leave=True)
        for step, batch in enumerate(tk):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            start_positions = batch['start_positions'].to(device)
            end_positions = batch['end_positions'].to(device)
            language = batch['language'].to(device)
        
            output = model(input_ids, attention_mask)
            start_logits, end_logits = output['start_logits'], output['end_logits']
            
            start_loss = criterion(start_logits, start_positions)
            end_loss = criterion(end_logits, end_positions)
            loss = (start_loss + end_loss) / 2
            
            all_start_logits.extend(start_logits.cpu().numpy())
            all_end_logits.extend(end_logits.cpu().numpy())
            
            losses.update(loss.item(), input_ids.size(0))
            tk.set_postfix(loss=losses.avg)
      
    all_start_logits, all_end_logits = np.array(all_start_logits), np.array(all_end_logits)
    
    final_predictions = postprocess_qa_predictions(valid_dataset, validation_features, (all_start_logits, all_end_logits))

    references = [{"id": ex["id"], "answer": ex["answers"]['text'][0]} for ex in valid_dataset]
    res = pd.DataFrame(references)
    res['prediction'] = res['id'].apply(lambda r: final_predictions[r])
    res['jaccard'] = res[['answer', 'prediction']].apply(cal_jaccard, axis=1)
    jaccard = res['jaccard'].mean()

    print(jaccard)
    
    return losses.avg, jaccard

In [None]:
seed_everything(CFG['seed'])

tokenizer = XLMRobertaTokenizerFast.from_pretrained(CFG['model'])
# tokenizer = AutoTokenizer.from_pretrained(CFG['model'])

folds = StratifiedKFold(n_splits=CFG['fold_num'], shuffle=True, random_state=CFG['seed'])\
                    .split(np.arange(train.shape[0]), train['language'].values)

cv = [] 

for fold, (trn_idx, val_idx) in enumerate(folds):
    
    print(fold)
    
    df_train = train.loc[trn_idx].reset_index()
    df_valid = train.loc[val_idx].reset_index()
    
    df_train = pd.concat([df_train, ext1.sample(500), ext2.sample(500)], 0)
                
    train_dataset = Dataset.from_pandas(df_train)
    valid_dataset = Dataset.from_pandas(df_valid)

    tokenized_train_ds = train_dataset.map(prepare_train_features, batched=True, remove_columns=train_dataset.column_names)
    tokenized_valid_ds = valid_dataset.map(prepare_train_features, batched=True, remove_columns=valid_dataset.column_names)
    validation_features = valid_dataset.map(prepare_validation_features, batched=True, remove_columns=valid_dataset.column_names)
    
    train_loader = DataLoader(tokenized_train_ds, shuffle=True, collate_fn=default_data_collator, batch_size=CFG['train_bs'])
    valid_loader = DataLoader(tokenized_valid_ds, shuffle=False, collate_fn=default_data_collator, batch_size=CFG['valid_bs'])

    best_score = 0
    
    model = AutoModelForQuestionAnswering.from_pretrained(CFG['model']).to(device)

    scaler = GradScaler()
    optimizer = AdamW(model.parameters(), lr=CFG['lr'], weight_decay=CFG['weight_decay'])
    criterion = nn.CrossEntropyLoss()
    scheduler = get_cosine_schedule_with_warmup(optimizer, len(train_loader), CFG['epochs']*len(train_loader))

    for epoch in range(CFG['epochs']):
        print('epoch:',epoch)
        time.sleep(0.2)

        train_loss = train_model(model, train_loader)
        val_loss, val_score = test_model(model, valid_loader)

        if val_score > best_score:
            best_score = val_score
            torch.save(model.state_dict(), '{}_fold{}.pt'.format(CFG['model'].split('/')[-1], fold))
            
    cv.append(best_score)

cv, np.mean(cv)

0


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


epoch: 0


100%|██████████| 1981/1981 [09:37<00:00,  3.43it/s, loss=1.96, lr=5e-6]   
100%|██████████| 239/239 [00:56<00:00,  4.25it/s, loss=0.259]
100%|██████████| 223/223 [00:06<00:00, 35.45it/s]


0.72199177877429
epoch: 1


100%|██████████| 1981/1981 [09:43<00:00,  3.40it/s, loss=0.453, lr=4.27e-6]
100%|██████████| 239/239 [00:58<00:00,  4.09it/s, loss=0.261]
100%|██████████| 223/223 [00:06<00:00, 34.18it/s]


0.7312559451012366
epoch: 2


100%|██████████| 1981/1981 [09:37<00:00,  3.43it/s, loss=0.316, lr=2.5e-6] 
100%|██████████| 239/239 [00:53<00:00,  4.46it/s, loss=0.25] 
100%|██████████| 223/223 [00:06<00:00, 36.54it/s]


0.710482062780269
epoch: 3


100%|██████████| 1981/1981 [09:23<00:00,  3.52it/s, loss=0.205, lr=7.32e-7]
100%|██████████| 239/239 [00:55<00:00,  4.28it/s, loss=0.289]
100%|██████████| 223/223 [00:06<00:00, 36.99it/s]


0.7192230601983964
epoch: 4


100%|██████████| 1981/1981 [09:30<00:00,  3.47it/s, loss=0.17, lr=0]        
100%|██████████| 239/239 [00:54<00:00,  4.36it/s, loss=0.294]
100%|██████████| 223/223 [00:06<00:00, 36.90it/s]

0.7207178285093084
1





HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


epoch: 0


100%|██████████| 1966/1966 [09:19<00:00,  3.51it/s, loss=2.09, lr=5e-6]   
100%|██████████| 247/247 [00:55<00:00,  4.42it/s, loss=0.26] 
100%|██████████| 223/223 [00:06<00:00, 35.67it/s]


0.696675098693036
epoch: 1


100%|██████████| 1966/1966 [09:22<00:00,  3.50it/s, loss=0.456, lr=4.27e-6]
100%|██████████| 247/247 [00:56<00:00,  4.35it/s, loss=0.257]
100%|██████████| 223/223 [00:06<00:00, 35.16it/s]


0.6734243335140196
epoch: 2


100%|██████████| 1966/1966 [09:22<00:00,  3.49it/s, loss=0.291, lr=2.5e-6] 
100%|██████████| 247/247 [00:56<00:00,  4.34it/s, loss=0.275]
100%|██████████| 223/223 [00:06<00:00, 33.86it/s]


0.6885143755098913
epoch: 3


100%|██████████| 1966/1966 [09:24<00:00,  3.48it/s, loss=0.21, lr=7.32e-7] 
100%|██████████| 247/247 [00:52<00:00,  4.70it/s, loss=0.292]
100%|██████████| 223/223 [00:06<00:00, 35.03it/s]


0.6943177479169926
epoch: 4


100%|██████████| 1966/1966 [09:01<00:00,  3.63it/s, loss=0.158, lr=0]       
100%|██████████| 247/247 [00:53<00:00,  4.66it/s, loss=0.303]
100%|██████████| 223/223 [00:06<00:00, 35.37it/s]


0.7077706627151988
2


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


epoch: 0


100%|██████████| 1983/1983 [09:00<00:00,  3.67it/s, loss=1.99, lr=5e-6]   
100%|██████████| 238/238 [00:50<00:00,  4.73it/s, loss=0.289]
100%|██████████| 223/223 [00:06<00:00, 36.41it/s]


0.7049148073416077
epoch: 1


100%|██████████| 1983/1983 [09:04<00:00,  3.64it/s, loss=0.456, lr=4.27e-6]
100%|██████████| 238/238 [00:50<00:00,  4.67it/s, loss=0.255]
100%|██████████| 223/223 [00:06<00:00, 36.34it/s]


0.7366422068215789
epoch: 2


100%|██████████| 1983/1983 [09:03<00:00,  3.65it/s, loss=0.31, lr=2.5e-6]  
100%|██████████| 238/238 [00:50<00:00,  4.75it/s, loss=0.269]
100%|██████████| 223/223 [00:06<00:00, 35.72it/s]


0.7233654611457301
epoch: 3


100%|██████████| 1983/1983 [09:04<00:00,  3.64it/s, loss=0.203, lr=7.32e-7]
100%|██████████| 238/238 [00:50<00:00,  4.68it/s, loss=0.316]
100%|██████████| 223/223 [00:06<00:00, 36.35it/s]


0.7229383844854697
epoch: 4


100%|██████████| 1983/1983 [09:03<00:00,  3.65it/s, loss=0.176, lr=0]       
100%|██████████| 238/238 [00:50<00:00,  4.70it/s, loss=0.319]
100%|██████████| 223/223 [00:06<00:00, 36.71it/s]


0.7248602294566421
3


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


epoch: 0


100%|██████████| 1916/1916 [08:43<00:00,  3.66it/s, loss=2.06, lr=5e-6]   
100%|██████████| 272/272 [00:57<00:00,  4.73it/s, loss=0.314]
100%|██████████| 223/223 [00:07<00:00, 31.49it/s]


0.7241405082212257
epoch: 1


100%|██████████| 1916/1916 [08:45<00:00,  3.65it/s, loss=0.424, lr=4.27e-6]
100%|██████████| 272/272 [00:57<00:00,  4.71it/s, loss=0.287]
100%|██████████| 223/223 [00:07<00:00, 31.49it/s]


0.7246636771300449
epoch: 2


100%|██████████| 1916/1916 [08:44<00:00,  3.65it/s, loss=0.281, lr=2.5e-6] 
100%|██████████| 272/272 [00:57<00:00,  4.76it/s, loss=0.289]
100%|██████████| 223/223 [00:07<00:00, 31.53it/s]


0.7250373692077728
epoch: 3


100%|██████████| 1916/1916 [08:43<00:00,  3.66it/s, loss=0.198, lr=7.32e-7]
100%|██████████| 272/272 [00:58<00:00,  4.64it/s, loss=0.331]
100%|██████████| 223/223 [00:07<00:00, 31.51it/s]


0.7207238949391417
epoch: 4


100%|██████████| 1916/1916 [08:43<00:00,  3.66it/s, loss=0.157, lr=0]       
100%|██████████| 272/272 [00:57<00:00,  4.77it/s, loss=0.341]
100%|██████████| 223/223 [00:07<00:00, 31.65it/s]


0.7259555840273328
4


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


epoch: 0


 71%|███████   | 1403/1989 [06:20<02:37,  3.72it/s, loss=2.46, lr=3.53e-6]