# Preparation

## Install Dependencies

In [None]:
# Install the third-party packages used by this notebook (IPython shell escapes, e.g. in Colab).
!pip install datasets
!pip install tokenizers
!pip install transformers
!pip install stanza
# -- Initialize Stanza --
# Download Stanza's English resources.
# NOTE(review): stanza is not referenced by any of the cells shown here — confirm it is still needed.
import stanza
stanza.download('en')

## Download Datasets

In [None]:
# -- OPEN-I dataset --
!mkdir -p open_i
!wget -q -N -P open_i https://openi.nlm.nih.gov/imgs/collections/NLMCXR_reports.tgz
!tar zxf open_i/NLMCXR_reports.tgz -C open_i
import glob
num_files = len(glob.glob("open_i/ecgen-radiology/*.xml"))
print(f'Downloaded {num_files} files from the OpenI dataset into "open_i/ecgen-radiology"')

# -- Transcriptions Dataset --
!mkdir -p medical_transcriptions
!gdown --id 1E0hm3r9bwK8cujyIcOjp_y-ZEPt1HBjn -O medical_transcriptions/medical_transcriptions.zip
!unzip -o medical_transcriptions/medical_transcriptions.zip -d medical_transcriptions
print(f'Downloaded medical transriptions dataset')

Downloaded 3955 files from the OpenI dataset into "open_i/ecgen-radiology"
Downloading...
From: https://drive.google.com/uc?id=1E0hm3r9bwK8cujyIcOjp_y-ZEPt1HBjn
To: /content/medical_transcriptions/medical_transcriptions.zip
100% 5.08M/5.08M [00:00<00:00, 10.5MB/s]
Archive:  medical_transcriptions/medical_transcriptions.zip
  inflating: medical_transcriptions/mtsamples.csv  
Downloaded medical transriptions dataset


# Training a Text Classifier

## Step 1 - Load the Medical Transcriptions Dataset
We first load the Medical Transcriptions dataset as a Hugging Face dataset. The dataset is split into a train and a test set, so the returned type is a `datasets.DatasetDict`, which acts like a dictionary of `datasets.Dataset` objects but provides utility functions, e.g. for mapping all datasets of the dict with the `map` method.
This step is already implemented.

In [None]:
from typing import List, Tuple
import pandas as pd
from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split

def load_transciptions_dataset(
    csv_path: str = 'medical_transcriptions/mtsamples.csv',
    min_samples: int = 100,
    test_size: float = 0.25,
    random_state: "int | None" = None,
) -> Tuple[DatasetDict, List[str]]:
    """Load the medical transcriptions CSV as a stratified train/test DatasetDict.

    Specialties with fewer than ``min_samples`` rows (counted before rows with
    missing values are dropped) are removed. Then every row with any missing
    value is dropped. The remaining ``medical_specialty`` strings are encoded
    as categorical integer codes.

    :param csv_path: path of the mtsamples CSV file.
    :param min_samples: minimum number of rows a specialty needs in order to be kept.
    :param test_size: fraction of rows placed in the (stratified) test split.
    :param random_state: seed for the split; ``None`` gives a different split on every call.
    :return: ``(dataset, class_names)``, where ``dataset`` has ``'train'`` and ``'test'``
        splits and ``class_names[i]`` is the specialty encoded as label ``i``.
    """
    df = pd.read_csv(csv_path)
    df = df.drop(['Unnamed: 0'], axis=1)

    counts = df['medical_specialty'].value_counts()
    print(f'Original counts:\n{counts}')
    dropped_specialities = counts[counts < min_samples].index
    # One vectorised filter instead of one pass per dropped class. .copy() gives an
    # independent frame, so the column assignments below do not raise pandas'
    # SettingWithCopyWarning.
    keep = ~df['medical_specialty'].isin(dropped_specialities)
    df = df[keep].dropna().copy()
    counts = df['medical_specialty'].value_counts()
    print(f'Counts after removing small specialities:\n{counts}')

    df['medical_specialty'] = df['medical_specialty'].astype('category')
    class_names = df['medical_specialty'].cat.categories.tolist()
    print(f'Class names : {class_names}')
    df['medical_specialty'] = df['medical_specialty'].cat.codes

    train, test = train_test_split(
        df,
        stratify=df['medical_specialty'],
        test_size=test_size,
        random_state=random_state,
    )
    # preserve_index=False keeps the shuffled pandas index from becoming an
    # extra '__index_level_0__' column.
    dataset = DatasetDict({
        'train': Dataset.from_pandas(train, preserve_index=False),
        'test': Dataset.from_pandas(test, preserve_index=False),
    })

    return dataset, class_names

In [None]:
# Load the train/test DatasetDict together with the list of class names
# (class_names[i] is the specialty encoded as integer label i).
dataset, class_names = load_transciptions_dataset()

Original counts:
medical_specialty
 Surgery                          1103
 Consult - History and Phy.        516
 Cardiovascular / Pulmonary        372
 Orthopedic                        355
 Radiology                         273
 General Medicine                  259
 Gastroenterology                  230
 Neurology                         223
 SOAP / Chart / Progress Notes     166
 Obstetrics / Gynecology           160
 Urology                           158
 Discharge Summary                 108
 ENT - Otolaryngology               98
 Neurosurgery                       94
 Hematology - Oncology              90
 Ophthalmology                      83
 Nephrology                         81
 Emergency Room Reports             75
 Pediatrics - Neonatal              70
 Pain Management                    62
 Psychiatry / Psychology            53
 Office Notes                       51
 Podiatry                           47
 Dermatology                        29
 Cosmetic / Plastic Surgery  

In [None]:
# Report how many transcriptions each split holds, then the column schema.
for split_caption, split_key in (('Training samples: ', 'train'), ('Test samples: ', 'test')):
    print(split_caption, len(dataset[split_key]['transcription']))
print('Columns: ', dataset['train'].features)

Training samples:  2315
Test samples:  772
Columns:  {'description': Value(dtype='string', id=None), 'medical_specialty': Value(dtype='int8', id=None), 'sample_name': Value(dtype='string', id=None), 'transcription': Value(dtype='string', id=None), 'keywords': Value(dtype='string', id=None)}


## Step 2 - Compute sentence embeddings using pre-trained language model
Our goal is to use a pre-trained language model (from the huggingface transformers library) to encode sentences into sentence representations. Therefore, each sentence is tokenized and passed to the language model. The resulting contextualized token representations (i.e. the outputs of the last hidden layer of the language model) are then globally pooled to get sentence-level representations. As we use a pre-trained model, no training is involved in this step.

### Specify pre-trained language model
Decide which pre-trained language model you want to use and specify its name here.
You can search for models on the Hugging Face model hub: https://huggingface.co/models . You can use either standard models, e.g. BERT, or biomedical models.

Once you have decided on a model, copy its name from the URL (removing only the https://huggingface.co/ prefix) and assign it to model_name here.
For more reference on how to use the tokenizer have a look at https://huggingface.co/docs/transformers/preprocessing.

In [None]:
# Hugging Face Hub id (the URL path after https://huggingface.co/) of the pre-trained encoder.
model_name: str = "google-bert/bert-base-uncased"

In [None]:
# Load the tokenizer
# The tokenizer is resolved from `model_name`, so it matches the encoder loaded in the next
# cell. Printing it shows e.g. vocab size, model_max_length and padding side.
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
print(tokenizer)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

BertTokenizerFast(name_or_path='google-bert/bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}


In [None]:
# Load the model and inspect it
# AutoModel loads the pre-trained encoder for `model_name` (printed as a BertModel here);
# its `last_hidden_state` output is what later gets pooled into sentence embeddings.
from transformers import AutoModel
model = AutoModel.from_pretrained(model_name)
print(model)

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
  

### Experiment with tokenizer and language model

In [None]:
# Tokenize a batch of two transcriptions (truncated to 512 tokens, not padded)
# and inspect what the tokenizer returns.
sample_text = dataset['train'][:2]['transcription']  # list with 2 raw texts
print('example text: ', sample_text)
toknized = tokenizer(sample_text, max_length=512, truncation=True)
print(toknized.keys())
for caption, column in (('input ids: ', 'input_ids'), ('attention mask: ', 'attention_mask')):
    print(caption, toknized[column])
print('length: ', len(toknized['input_ids'][0]))

example text:  ['PREOPERATIVE DIAGNOSIS: , Herniated nucleus pulposus C5-C6.,POSTOPERATIVE DIAGNOSIS: , Herniated nucleus pulposus C5-C6.,PROCEDURE:,  Anterior cervical discectomy fusion C5-C6 followed by instrumentation C5-C6 with titanium dynamic plating system, Aesculap.  Operating microscope was used for both illumination and magnification.,FIRST ASSISTANT: , Nurse practitioner.,PROCEDURE IN DETAIL: , The patient was placed in supine position.  The neck was prepped and draped in the usual fashion for anterior discectomy and fusion.  An incision was made midline to the anterior body of the sternocleidomastoid at C5-C6 level.  The skin, subcutaneous tissue, and platysma muscle was divided exposing the carotid sheath, which was retracted laterally.  Trachea and esophagus were retracted medially.  After placing the self-retaining retractors with the longus colli muscles having been dissected away from the vertebral bodies at C5 and C6 and confirming our position with intraoperative x-r

In [None]:
# Feed some example data to understand the outputs of the language model.
# Pad the (unpadded) batch to a common length and convert it to PyTorch:
# return_tensors='pt' => return the tokenized values as PyTorch tensors instead of lists.
import torch

x = tokenizer.pad(toknized, return_tensors='pt')
print(x.keys())
print('input ids: ', x['input_ids'])
print('attention mask: ', x['attention_mask'])
print('shape: ', x['input_ids'].shape)

# Encode the tokenized and padded input. This is inference only: no_grad stops autograd
# from recording a graph (and retaining intermediate activations) for the forward pass.
with torch.no_grad():
    results = model(**x)
print('output shape: ', results.last_hidden_state.shape)  # (batch_size x num_tokens x d_hidden)

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
input ids:  tensor([[  101,  3653, 25918,  ...,  2008,  2001,   102],
        [  101,  3653, 25918,  ..., 19470,  1012,   102]])
attention mask:  tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]])
shape:  torch.Size([2, 512])
output shape:  torch.Size([2, 512, 768])


### Tokenize the dataset
We now tokenize the whole dataset using the batched map-method.

TODO: Implement the tokenize_batch function using the tokenizer. Note: do not pad the data yet but truncate it to a max length of 512.

In [None]:
def tokenize_batch(text_batch: List[str]):
    """Encode a batch of texts, truncating each to 512 tokens and applying no padding.

    :param text_batch: raw texts to encode.
    :return: the tokenizer's encoding (``input_ids``, ``attention_mask``, ...) with
        plain Python lists as values; padding is applied later, per batch.
    """
    encoding = tokenizer(
        text_batch,
        max_length=512,
        truncation=True,
        return_tensors=None,  # lists, not tensors: sequences still have different lengths
    )
    return encoding

In [None]:
# Apply tokenize_batch to every split. Then keep only the model inputs plus the
# integer label, renamed to 'labels'.
tokenized_dataset = dataset.map(
    lambda rows: tokenize_batch(rows["transcription"]),
    batched=True,
)
text_columns = ['sample_name', 'transcription', 'description', 'keywords']
tokenized_dataset = tokenized_dataset.remove_columns(text_columns)
tokenized_dataset = tokenized_dataset.rename_column('medical_specialty', 'labels')
print('Columns after tokenization', list(tokenized_dataset['train'].features.keys()))

Map:   0%|          | 0/2315 [00:00<?, ? examples/s]

Map:   0%|          | 0/772 [00:00<?, ? examples/s]

Columns after tokenization ['labels', 'input_ids', 'token_type_ids', 'attention_mask']


### Token Pooling
Define Pooling functions to compute sentence embeddings from token outputs.

Implement three possible pooling functions:
- CLS: use the output of the [CLS] token only and ignore the other tokens.
- max/avg: globally max/avg pool over all token outputs, but make sure to ignore padding tokens.

In [None]:
import torch

def CLS_pool(hidden_state: torch.FloatTensor, attention_mask: torch.BoolTensor):
    """
    Pools by keeping only the output of the first token ([CLS]) of every sequence.
    :param hidden_state: (N x M x d_hidden) token outputs; N is the batch size, M the number of tokens
    :param attention_mask: (N x M) padding mask; unused, because with right-side padding
        position 0 always holds a real token
    :return (N x d_hidden)
    """
    first_token_index = 0
    return hidden_state.select(1, first_token_index)

def max_pool(hidden_state: torch.FloatTensor, attention_mask: torch.BoolTensor):
    """
    Globally pools the hidden states over all non-padding tokens using max pooling.
    :param hidden_state: (N x M x d_hidden) where N is the batch size and M is the number of tokens
    :param attention_mask: (N x M), True/non-zero for "real" tokens, False/0 for padding tokens.
        Integer 0/1 masks, like the ``attention_mask`` produced by the tokenizer, are accepted.
    :return (N x d_hidden); a row that contains no real token yields -inf.
    """
    # masked_fill/where need a boolean mask, but the tokenizer's attention_mask is an
    # integer tensor, so cast explicitly. The (N x M x 1) shape broadcasts over d_hidden.
    is_token = attention_mask.to(dtype=torch.bool).unsqueeze(-1)
    # Padding positions become -inf so they can never be picked as the maximum.
    masked_hidden_state = hidden_state.masked_fill(~is_token, float('-inf'))
    return masked_hidden_state.max(dim=1).values

def avg_pool(hidden_state: torch.FloatTensor, attention_mask: torch.BoolTensor):
    """
    Averages the token outputs over the non-padding positions of each sequence.
    :param hidden_state: (N x M x d_hidden) token outputs; N is the batch size, M the number of tokens
    :param attention_mask: (N x M), True/1 for "real" tokens, False/0 for padding tokens
    :return (N x d_hidden)
    """
    weights = attention_mask.unsqueeze(-1).expand_as(hidden_state)  # (N x M x d_hidden)
    summed = (hidden_state * weights).sum(dim=1)  # padding positions contribute 0
    # Real-token count per sequence (repeated over d_hidden). The clamp to 1 makes an
    # all-padding row give zeros instead of NaN from 0 / 0.
    n_real = weights.sum(dim=1).clamp(min=1)
    return summed / n_real


### Sentence Embedder
The sentence embedder takes a tokenized input, uses the language model to compute token representations (i.e. the last hidden_state of the language model) and then uses a pooling function to compute a single sentence representation.

Implement the forward method of the sentence embedder.

In [None]:
from torch import nn
from typing import Dict

class SentenceEmbedder(nn.Module):
    """Combine a language model and a pooling function into a sentence encoder.

    The language model produces token representations (``last_hidden_state``),
    and ``pool`` reduces them to one vector per sequence.
    """

    def __init__(self, model, pool):
        super().__init__()
        # When `model` is an nn.Module it is registered as a submodule, so calls such as
        # .to(device) on the embedder also reach it.
        self.model = model
        # Callable (hidden_state, attention_mask) -> (N x d_hidden), e.g. CLS_pool / max_pool / avg_pool.
        self.pool = pool

    def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Embed a padded batch of token ids.
        :param x: dict containing:
            - input_ids: torch.Tensor (N x M)
            - attention_mask: torch.Tensor (N x M)
          Any other keys (e.g. token_type_ids) are ignored.
        :return: Sentence embedding for each sample of shape (N x d_hidden)
        """
        mask = x['attention_mask']
        token_states = self.model(input_ids=x['input_ids'], attention_mask=mask).last_hidden_state
        return self.pool(token_states, mask)


Now decide which pooling function you want to use.

In [None]:
import torch

# Run the encoder on the first GPU when one is available; fall back to the CPU
# instead of failing with a CUDA error on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
pool = avg_pool
# Directory the encoded splits are saved to. It is derived from the pooling function,
# so switching `pool` does not overwrite embeddings computed with a different pooling
# (the training cell's dataset_name must name the same directory).
dataset_name = f'encoded_transciptions_{pool.__name__}'
# eval(): embeddings are computed for inference only, so make sure dropout is disabled.
sentence_embedder = SentenceEmbedder(model, pool).to(device=device).eval()

 Now run the sentence embedder.

 The resulting dataset will then contain a sentence_embedding column.

In [None]:
def embed_sentence(batch):
    """Pad one tokenized batch, embed it on `device`, and return it as a new column.

    :param batch: batched mapping from ``Dataset.map`` holding ``input_ids`` and
        ``attention_mask`` lists
    :return: ``{'sentence_embedding': ndarray}`` with one row per sample, moved to the CPU
    """
    unpadded = {key: batch[key] for key in ('input_ids', 'attention_mask')}
    padded = tokenizer.pad(unpadded, return_tensors='pt').to(device=device)

    with torch.no_grad():
        embeddings = sentence_embedder(padded)
    return {'sentence_embedding': embeddings.detach().cpu().numpy()}

encoded_dataset = tokenized_dataset.map(embed_sentence, batch_size=64, batched=True)
encoded_dataset.save_to_disk(dataset_name)

Map:   0%|          | 0/2315 [00:00<?, ? examples/s]

Map:   0%|          | 0/772 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2315 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/772 [00:00<?, ? examples/s]

## Step 5 - Train Classification Model on Sentence Embeddings
Now we have a single representation vector for each sentence. We now learn a simple classifier based on an MLP (i.e. a 2-layer fully-connected neural network). We first project each sentence embedding from d_hidden to d_mlp, apply ReLU (or another non-linearity), and then project the resulting vector to num_classes logits, to which we apply the cross-entropy loss. The classification head is then learned from the training data. This is the only learned component in our model.


In [None]:
from torch import nn
import datasets
from datasets import load_from_disk
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm
import torch.nn.functional as F

class ClassificationHead(nn.Module):
    """Two-layer MLP that classifies fixed sentence embeddings.

    Architecture: Linear(d_hidden -> d_mlp) -> ReLU -> Linear(d_mlp -> num_classes),
    trained with cross-entropy loss.
    """

    def __init__(self, d_hidden: int, d_mlp: int, num_classes: int):
        super().__init__()
        # The attribute names fc1 / fc2 determine the state_dict keys; keep them stable.
        self.fc1 = nn.Linear(d_hidden, d_mlp)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(d_mlp, num_classes)
        self.loss_function = nn.CrossEntropyLoss()

    def forward(self, x, y_true):
        """
        Classify a batch of sentence embeddings and score it against the targets.
        :param x: Sentence embeddings (N x d_hidden)
        :param y_true: Integer target classes (N)
        :return: Tuple (y_pred, loss): the highest-scoring class per sample (N) and
            the cross-entropy loss of the batch
        """
        logits = self.fc2(self.relu(self.fc1(x)))
        return logits.argmax(dim=1), self.loss_function(logits, y_true)


Now train the classification head on the sentence embeddings.
The training is already implemented.

In [None]:
import torch

# Hyper-parameters for training the classification head.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'  # fall back to the CPU without a GPU
dataset_name = 'encoded_transciptions_avg_pool'  # directory written by the sentence-embedding step
num_epochs = 300
lr = 1e-4
weight_decay = 1e-4
d_hidden = 768  # must match the language model hidden output
d_mlp = 1024
batch_size = 128

In [None]:
# Now run the training...
from sklearn.metrics import f1_score


def _weighted_mean(batch_means, batch_sizes):
    """Combine per-batch mean losses into a per-sample mean (0.0 when there are no samples)."""
    total = sum(mean * size for mean, size in zip(batch_means, batch_sizes))
    return total / max(sum(batch_sizes), 1)


print('num classes: ', len(class_names))
encoded_dataset = load_from_disk(dataset_name)
# make sure PyTorch Tensors are returned
encoded_dataset.set_format('pt')

# remove all columns except sentence_embedding and labels
# (passed as a list, which is what remove_columns documents)
columns_to_remove = sorted(set(encoded_dataset['train'].column_names) - {'sentence_embedding', 'labels'})
encoded_dataset = encoded_dataset.remove_columns(columns_to_remove)

train_data_loader = DataLoader(encoded_dataset['train'], batch_size=batch_size, shuffle=True)
test_data_loader = DataLoader(encoded_dataset['test'], batch_size=batch_size, shuffle=False)

classification_model = ClassificationHead(d_hidden=d_hidden,
                                          d_mlp=d_mlp,
                                          num_classes=len(class_names))
classification_model = classification_model.to(device=device)
optimizer = torch.optim.Adam(classification_model.parameters(), lr=lr, weight_decay=weight_decay)

classification_model.train()
print('Training...')
for epoch in range(num_epochs):
    train_loss, train_sizes = [], []
    for train_batch in tqdm(train_data_loader):
        x = train_batch['sentence_embedding'].to(device=device)
        y_true = train_batch['labels'].to(device=device)

        y_pred, loss = classification_model(x, y_true)
        train_loss.append(loss.item())
        train_sizes.append(len(y_true))

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    # The last batch is usually smaller than batch_size, so weight each batch mean by
    # its size; this gives the true per-sample average loss.
    print('train loss: ', _weighted_mean(train_loss, train_sizes))

classification_model.eval()
print('Testing...')
all_true, all_pred = [], []
with torch.no_grad():
    test_loss, test_sizes = [], []
    for test_batch in tqdm(test_data_loader):
        x = test_batch['sentence_embedding'].to(device=device)
        y_true = test_batch['labels'].to(device=device)

        y_pred, loss = classification_model(x, y_true)

        # Collect labels on the CPU as plain ints for the metric.
        all_true.extend(y_true.cpu().tolist())
        all_pred.extend(y_pred.cpu().tolist())
        test_loss.append(loss.item())
        test_sizes.append(len(y_true))
print('\ntest loss: ', _weighted_mean(test_loss, test_sizes))
# datasets.load_metric is deprecated and was removed in datasets 3.0. Its "f1" metric
# wrapped sklearn's f1_score, so compute the macro F1 directly, in the same dict shape.
print('F1 (Macro): ', {'f1': float(f1_score(all_true, all_pred, average='macro'))})

num classes:  12
Training...


100%|██████████| 19/19 [00:00<00:00, 69.95it/s]


train loss:  2.255821943283081


100%|██████████| 19/19 [00:00<00:00, 70.63it/s]


train loss:  2.091909408569336


100%|██████████| 19/19 [00:00<00:00, 81.66it/s]


train loss:  2.0449535846710205


100%|██████████| 19/19 [00:00<00:00, 96.94it/s] 


train loss:  1.96532142162323


100%|██████████| 19/19 [00:00<00:00, 76.00it/s]


train loss:  1.8754944801330566


100%|██████████| 19/19 [00:00<00:00, 70.31it/s]


train loss:  1.8075910806655884


100%|██████████| 19/19 [00:00<00:00, 74.43it/s]


train loss:  1.7510221004486084


100%|██████████| 19/19 [00:00<00:00, 69.20it/s]


train loss:  1.7273386716842651


100%|██████████| 19/19 [00:00<00:00, 137.11it/s]


train loss:  1.6792590618133545


100%|██████████| 19/19 [00:00<00:00, 133.19it/s]


train loss:  1.6214385032653809


100%|██████████| 19/19 [00:00<00:00, 139.25it/s]


train loss:  1.588261604309082


100%|██████████| 19/19 [00:00<00:00, 137.20it/s]


train loss:  1.581997275352478


100%|██████████| 19/19 [00:00<00:00, 149.49it/s]


train loss:  1.566233515739441


100%|██████████| 19/19 [00:00<00:00, 150.74it/s]


train loss:  1.5160290002822876


100%|██████████| 19/19 [00:00<00:00, 145.28it/s]


train loss:  1.509777307510376


100%|██████████| 19/19 [00:00<00:00, 151.57it/s]


train loss:  1.4418513774871826


100%|██████████| 19/19 [00:00<00:00, 148.02it/s]


train loss:  1.4691624641418457


100%|██████████| 19/19 [00:00<00:00, 154.11it/s]


train loss:  1.4433153867721558


100%|██████████| 19/19 [00:00<00:00, 143.44it/s]


train loss:  1.4198007583618164


100%|██████████| 19/19 [00:00<00:00, 155.80it/s]


train loss:  1.3809489011764526


100%|██████████| 19/19 [00:00<00:00, 153.29it/s]


train loss:  1.4073797464370728


100%|██████████| 19/19 [00:00<00:00, 155.04it/s]


train loss:  1.3802127838134766


100%|██████████| 19/19 [00:00<00:00, 152.15it/s]


train loss:  1.3549822568893433


100%|██████████| 19/19 [00:00<00:00, 153.68it/s]


train loss:  1.3714600801467896


100%|██████████| 19/19 [00:00<00:00, 151.92it/s]


train loss:  1.3371648788452148


100%|██████████| 19/19 [00:00<00:00, 148.46it/s]


train loss:  1.3243207931518555


100%|██████████| 19/19 [00:00<00:00, 139.96it/s]


train loss:  1.292508602142334


100%|██████████| 19/19 [00:00<00:00, 150.41it/s]


train loss:  1.2961468696594238


100%|██████████| 19/19 [00:00<00:00, 152.37it/s]


train loss:  1.2886029481887817


100%|██████████| 19/19 [00:00<00:00, 150.56it/s]


train loss:  1.2712528705596924


100%|██████████| 19/19 [00:00<00:00, 145.36it/s]


train loss:  1.2760359048843384


100%|██████████| 19/19 [00:00<00:00, 145.59it/s]


train loss:  1.2668946981430054


100%|██████████| 19/19 [00:00<00:00, 151.85it/s]


train loss:  1.2561465501785278


100%|██████████| 19/19 [00:00<00:00, 131.59it/s]


train loss:  1.2434298992156982


100%|██████████| 19/19 [00:00<00:00, 151.35it/s]


train loss:  1.248634934425354


100%|██████████| 19/19 [00:00<00:00, 148.76it/s]


train loss:  1.2349348068237305


100%|██████████| 19/19 [00:00<00:00, 151.41it/s]


train loss:  1.217434287071228


100%|██████████| 19/19 [00:00<00:00, 152.49it/s]


train loss:  1.2294785976409912


100%|██████████| 19/19 [00:00<00:00, 148.20it/s]


train loss:  1.2086535692214966


100%|██████████| 19/19 [00:00<00:00, 152.72it/s]


train loss:  1.2063801288604736


100%|██████████| 19/19 [00:00<00:00, 137.38it/s]


train loss:  1.2004867792129517


100%|██████████| 19/19 [00:00<00:00, 150.85it/s]


train loss:  1.207106590270996


100%|██████████| 19/19 [00:00<00:00, 151.49it/s]


train loss:  1.1952412128448486


100%|██████████| 19/19 [00:00<00:00, 148.87it/s]


train loss:  1.203960657119751


100%|██████████| 19/19 [00:00<00:00, 150.65it/s]


train loss:  1.1841567754745483


100%|██████████| 19/19 [00:00<00:00, 152.82it/s]


train loss:  1.1772350072860718


100%|██████████| 19/19 [00:00<00:00, 142.06it/s]


train loss:  1.1531795263290405


100%|██████████| 19/19 [00:00<00:00, 137.61it/s]


train loss:  1.1793744564056396


100%|██████████| 19/19 [00:00<00:00, 225.51it/s]


train loss:  1.1609201431274414


100%|██████████| 19/19 [00:00<00:00, 232.15it/s]


train loss:  1.1932176351547241


100%|██████████| 19/19 [00:00<00:00, 227.23it/s]


train loss:  1.1574652194976807


100%|██████████| 19/19 [00:00<00:00, 232.81it/s]


train loss:  1.149408221244812


100%|██████████| 19/19 [00:00<00:00, 230.50it/s]


train loss:  1.1598418951034546


100%|██████████| 19/19 [00:00<00:00, 233.53it/s]


train loss:  1.1918622255325317


100%|██████████| 19/19 [00:00<00:00, 234.50it/s]


train loss:  1.1311801671981812


100%|██████████| 19/19 [00:00<00:00, 193.27it/s]


train loss:  1.1692204475402832


100%|██████████| 19/19 [00:00<00:00, 231.37it/s]


train loss:  1.1310274600982666


100%|██████████| 19/19 [00:00<00:00, 234.57it/s]


train loss:  1.1379390954971313


100%|██████████| 19/19 [00:00<00:00, 211.29it/s]


train loss:  1.118511438369751


100%|██████████| 19/19 [00:00<00:00, 221.56it/s]


train loss:  1.119810938835144


100%|██████████| 19/19 [00:00<00:00, 234.93it/s]


train loss:  1.1249099969863892


100%|██████████| 19/19 [00:00<00:00, 199.20it/s]


train loss:  1.121363639831543


100%|██████████| 19/19 [00:00<00:00, 223.90it/s]


train loss:  1.1092766523361206


100%|██████████| 19/19 [00:00<00:00, 234.02it/s]


train loss:  1.1051421165466309


100%|██████████| 19/19 [00:00<00:00, 218.13it/s]


train loss:  1.1152985095977783


100%|██████████| 19/19 [00:00<00:00, 234.32it/s]


train loss:  1.0834541320800781


100%|██████████| 19/19 [00:00<00:00, 234.97it/s]


train loss:  1.0856136083602905


100%|██████████| 19/19 [00:00<00:00, 230.29it/s]


train loss:  1.0983511209487915


100%|██████████| 19/19 [00:00<00:00, 227.65it/s]


train loss:  1.0811489820480347


100%|██████████| 19/19 [00:00<00:00, 203.31it/s]


train loss:  1.091950535774231


100%|██████████| 19/19 [00:00<00:00, 234.80it/s]


train loss:  1.0735946893692017


100%|██████████| 19/19 [00:00<00:00, 235.63it/s]


train loss:  1.0865819454193115


100%|██████████| 19/19 [00:00<00:00, 234.27it/s]


train loss:  1.0688549280166626


100%|██████████| 19/19 [00:00<00:00, 235.28it/s]


train loss:  1.0731180906295776


100%|██████████| 19/19 [00:00<00:00, 235.01it/s]


train loss:  1.0645990371704102


100%|██████████| 19/19 [00:00<00:00, 231.88it/s]


train loss:  1.0537797212600708


100%|██████████| 19/19 [00:00<00:00, 227.25it/s]


train loss:  1.0688759088516235


100%|██████████| 19/19 [00:00<00:00, 229.89it/s]


train loss:  1.0444446802139282


100%|██████████| 19/19 [00:00<00:00, 176.27it/s]


train loss:  1.060491681098938


100%|██████████| 19/19 [00:00<00:00, 158.00it/s]


train loss:  1.033543348312378


100%|██████████| 19/19 [00:00<00:00, 153.82it/s]


train loss:  1.0661932229995728


100%|██████████| 19/19 [00:00<00:00, 156.42it/s]


train loss:  1.074342131614685


100%|██████████| 19/19 [00:00<00:00, 151.39it/s]


train loss:  1.0534933805465698


100%|██████████| 19/19 [00:00<00:00, 147.29it/s]


train loss:  1.040039300918579


100%|██████████| 19/19 [00:00<00:00, 152.65it/s]


train loss:  1.03398859500885


100%|██████████| 19/19 [00:00<00:00, 150.83it/s]


train loss:  1.0331388711929321


100%|██████████| 19/19 [00:00<00:00, 158.40it/s]


train loss:  1.0354468822479248


100%|██████████| 19/19 [00:00<00:00, 163.39it/s]


train loss:  1.0582232475280762


100%|██████████| 19/19 [00:00<00:00, 155.49it/s]


train loss:  1.0279682874679565


100%|██████████| 19/19 [00:00<00:00, 152.78it/s]


train loss:  1.0299062728881836


100%|██████████| 19/19 [00:00<00:00, 148.35it/s]


train loss:  1.0386883020401


100%|██████████| 19/19 [00:00<00:00, 155.79it/s]


train loss:  1.0101401805877686


100%|██████████| 19/19 [00:00<00:00, 143.59it/s]


train loss:  1.030332088470459


100%|██████████| 19/19 [00:00<00:00, 149.06it/s]


train loss:  1.0272667407989502


100%|██████████| 19/19 [00:00<00:00, 150.17it/s]


train loss:  1.011596441268921


100%|██████████| 19/19 [00:00<00:00, 142.54it/s]


train loss:  1.0127105712890625


100%|██████████| 19/19 [00:00<00:00, 143.98it/s]


train loss:  1.0040249824523926


100%|██████████| 19/19 [00:00<00:00, 155.07it/s]


train loss:  0.9958978891372681


100%|██████████| 19/19 [00:00<00:00, 146.45it/s]


train loss:  1.0002094507217407


100%|██████████| 19/19 [00:00<00:00, 145.58it/s]


train loss:  1.0197385549545288


100%|██████████| 19/19 [00:00<00:00, 151.31it/s]


train loss:  1.0185123682022095


100%|██████████| 19/19 [00:00<00:00, 133.97it/s]


train loss:  0.9995478391647339


100%|██████████| 19/19 [00:00<00:00, 229.84it/s]


train loss:  0.986818790435791


100%|██████████| 19/19 [00:00<00:00, 234.96it/s]


train loss:  0.9895590543746948


100%|██████████| 19/19 [00:00<00:00, 199.61it/s]


train loss:  0.9899032711982727


100%|██████████| 19/19 [00:00<00:00, 224.45it/s]


train loss:  0.9767133593559265


100%|██████████| 19/19 [00:00<00:00, 218.98it/s]


train loss:  0.9912810325622559


100%|██████████| 19/19 [00:00<00:00, 224.31it/s]


train loss:  0.9805412292480469


100%|██████████| 19/19 [00:00<00:00, 225.09it/s]


train loss:  0.9797951579093933


100%|██████████| 19/19 [00:00<00:00, 226.68it/s]


train loss:  0.9766833782196045


100%|██████████| 19/19 [00:00<00:00, 220.25it/s]


train loss:  0.9738783836364746


100%|██████████| 19/19 [00:00<00:00, 212.44it/s]


train loss:  0.9604085683822632


100%|██████████| 19/19 [00:00<00:00, 228.03it/s]


train loss:  0.9694944620132446


100%|██████████| 19/19 [00:00<00:00, 220.67it/s]


train loss:  0.9825400114059448


100%|██████████| 19/19 [00:00<00:00, 230.51it/s]


train loss:  0.9696716666221619


100%|██████████| 19/19 [00:00<00:00, 198.90it/s]


train loss:  0.9660655856132507


100%|██████████| 19/19 [00:00<00:00, 227.85it/s]


train loss:  0.9616649150848389


100%|██████████| 19/19 [00:00<00:00, 228.97it/s]


train loss:  0.9552592039108276


100%|██████████| 19/19 [00:00<00:00, 215.69it/s]


train loss:  0.9512593150138855


100%|██████████| 19/19 [00:00<00:00, 223.99it/s]


train loss:  0.9508522152900696


100%|██████████| 19/19 [00:00<00:00, 231.71it/s]


train loss:  0.9548988342285156


100%|██████████| 19/19 [00:00<00:00, 226.30it/s]


train loss:  0.9785951375961304


100%|██████████| 19/19 [00:00<00:00, 218.42it/s]


train loss:  0.9446012377738953


100%|██████████| 19/19 [00:00<00:00, 235.04it/s]


train loss:  0.9402454495429993


100%|██████████| 19/19 [00:00<00:00, 219.58it/s]


train loss:  0.9611430168151855


100%|██████████| 19/19 [00:00<00:00, 223.68it/s]


train loss:  0.961022138595581


100%|██████████| 19/19 [00:00<00:00, 207.13it/s]


train loss:  0.9335312843322754


100%|██████████| 19/19 [00:00<00:00, 221.05it/s]


train loss:  0.9322760701179504


100%|██████████| 19/19 [00:00<00:00, 226.26it/s]


train loss:  0.9425264596939087


100%|██████████| 19/19 [00:00<00:00, 225.02it/s]


train loss:  0.9360963106155396


100%|██████████| 19/19 [00:00<00:00, 233.14it/s]


train loss:  0.9403516054153442


100%|██████████| 19/19 [00:00<00:00, 222.15it/s]


train loss:  0.9366644024848938


100%|██████████| 19/19 [00:00<00:00, 206.56it/s]


train loss:  0.9355835914611816


100%|██████████| 19/19 [00:00<00:00, 213.90it/s]


train loss:  0.9270821809768677


100%|██████████| 19/19 [00:00<00:00, 213.10it/s]


train loss:  0.9263724088668823


100%|██████████| 19/19 [00:00<00:00, 229.82it/s]


train loss:  0.924820065498352


100%|██████████| 19/19 [00:00<00:00, 199.13it/s]


train loss:  0.9093944430351257


100%|██████████| 19/19 [00:00<00:00, 227.45it/s]


train loss:  0.9118425250053406


100%|██████████| 19/19 [00:00<00:00, 219.49it/s]


train loss:  0.9298409819602966


100%|██████████| 19/19 [00:00<00:00, 213.57it/s]


train loss:  0.9159772992134094


100%|██████████| 19/19 [00:00<00:00, 223.46it/s]


train loss:  0.9185357093811035


100%|██████████| 19/19 [00:00<00:00, 223.47it/s]


train loss:  0.9252820014953613


100%|██████████| 19/19 [00:00<00:00, 221.19it/s]


train loss:  0.9086796641349792


100%|██████████| 19/19 [00:00<00:00, 225.34it/s]


train loss:  0.9193949699401855


100%|██████████| 19/19 [00:00<00:00, 222.57it/s]


train loss:  0.9102227091789246


100%|██████████| 19/19 [00:00<00:00, 212.87it/s]


train loss:  0.9137718081474304


100%|██████████| 19/19 [00:00<00:00, 200.59it/s]


train loss:  0.9110134243965149


100%|██████████| 19/19 [00:00<00:00, 176.29it/s]


train loss:  0.8965411186218262


100%|██████████| 19/19 [00:00<00:00, 219.29it/s]


train loss:  0.9128002524375916


100%|██████████| 19/19 [00:00<00:00, 226.77it/s]


train loss:  0.9137137532234192


100%|██████████| 19/19 [00:00<00:00, 225.30it/s]


train loss:  0.8994444012641907


100%|██████████| 19/19 [00:00<00:00, 217.65it/s]


train loss:  0.8979015350341797


100%|██████████| 19/19 [00:00<00:00, 223.84it/s]


train loss:  0.9081873893737793


100%|██████████| 19/19 [00:00<00:00, 216.78it/s]


train loss:  0.9007896780967712


100%|██████████| 19/19 [00:00<00:00, 227.51it/s]


train loss:  0.8981582522392273


100%|██████████| 19/19 [00:00<00:00, 224.04it/s]


train loss:  0.8923244476318359


100%|██████████| 19/19 [00:00<00:00, 216.85it/s]


train loss:  0.8826159238815308


100%|██████████| 19/19 [00:00<00:00, 212.68it/s]


train loss:  0.889491617679596


100%|██████████| 19/19 [00:00<00:00, 210.80it/s]


train loss:  0.8741374015808105


100%|██████████| 19/19 [00:00<00:00, 227.25it/s]


train loss:  0.9057384133338928


100%|██████████| 19/19 [00:00<00:00, 217.09it/s]


train loss:  0.8795380592346191


100%|██████████| 19/19 [00:00<00:00, 232.12it/s]


train loss:  0.8781461119651794


100%|██████████| 19/19 [00:00<00:00, 223.96it/s]


train loss:  0.8890068531036377


100%|██████████| 19/19 [00:00<00:00, 190.16it/s]


train loss:  0.8939443826675415


100%|██████████| 19/19 [00:00<00:00, 222.71it/s]


train loss:  0.8869158029556274


100%|██████████| 19/19 [00:00<00:00, 222.93it/s]


train loss:  0.8705786466598511


100%|██████████| 19/19 [00:00<00:00, 225.15it/s]


train loss:  0.8774036765098572


100%|██████████| 19/19 [00:00<00:00, 228.43it/s]


train loss:  0.8851466774940491


100%|██████████| 19/19 [00:00<00:00, 215.35it/s]


train loss:  0.8676345944404602


100%|██████████| 19/19 [00:00<00:00, 221.16it/s]


train loss:  0.8893921375274658


100%|██████████| 19/19 [00:00<00:00, 223.19it/s]


train loss:  0.872392475605011


100%|██████████| 19/19 [00:00<00:00, 232.24it/s]


train loss:  0.8603103756904602


100%|██████████| 19/19 [00:00<00:00, 223.32it/s]


train loss:  0.8624841570854187


100%|██████████| 19/19 [00:00<00:00, 229.20it/s]


train loss:  0.8618234395980835


100%|██████████| 19/19 [00:00<00:00, 227.01it/s]


train loss:  0.8608819842338562


100%|██████████| 19/19 [00:00<00:00, 223.00it/s]


train loss:  0.8760985136032104


100%|██████████| 19/19 [00:00<00:00, 225.63it/s]


train loss:  0.8601768612861633


100%|██████████| 19/19 [00:00<00:00, 214.34it/s]


train loss:  0.8644566535949707


100%|██████████| 19/19 [00:00<00:00, 224.16it/s]


train loss:  0.8645806312561035


100%|██████████| 19/19 [00:00<00:00, 209.27it/s]


train loss:  0.8542060852050781


100%|██████████| 19/19 [00:00<00:00, 221.64it/s]


train loss:  0.8500015139579773


100%|██████████| 19/19 [00:00<00:00, 227.00it/s]


train loss:  0.8392937183380127


100%|██████████| 19/19 [00:00<00:00, 218.03it/s]


train loss:  0.8569020628929138


100%|██████████| 19/19 [00:00<00:00, 226.24it/s]


train loss:  0.8441601991653442


100%|██████████| 19/19 [00:00<00:00, 225.23it/s]


train loss:  0.8524492383003235


100%|██████████| 19/19 [00:00<00:00, 218.25it/s]


train loss:  0.8516287803649902


100%|██████████| 19/19 [00:00<00:00, 216.42it/s]


train loss:  0.8434634208679199


100%|██████████| 19/19 [00:00<00:00, 225.35it/s]


train loss:  0.8576363921165466


100%|██████████| 19/19 [00:00<00:00, 226.13it/s]


train loss:  0.8434269428253174


100%|██████████| 19/19 [00:00<00:00, 221.66it/s]


train loss:  0.8727901577949524


100%|██████████| 19/19 [00:00<00:00, 201.44it/s]


train loss:  0.8421608805656433


100%|██████████| 19/19 [00:00<00:00, 222.02it/s]


train loss:  0.8646979331970215


100%|██████████| 19/19 [00:00<00:00, 226.32it/s]


train loss:  0.8497609496116638


100%|██████████| 19/19 [00:00<00:00, 229.01it/s]


train loss:  0.85672527551651


100%|██████████| 19/19 [00:00<00:00, 222.16it/s]


train loss:  0.8406977653503418


100%|██████████| 19/19 [00:00<00:00, 220.21it/s]


train loss:  0.8387427926063538


100%|██████████| 19/19 [00:00<00:00, 224.88it/s]


train loss:  0.8309216499328613


100%|██████████| 19/19 [00:00<00:00, 223.09it/s]


train loss:  0.858675479888916


100%|██████████| 19/19 [00:00<00:00, 224.86it/s]


train loss:  0.8369419574737549


100%|██████████| 19/19 [00:00<00:00, 225.31it/s]


train loss:  0.8208422064781189


100%|██████████| 19/19 [00:00<00:00, 223.29it/s]


train loss:  0.8179783821105957


100%|██████████| 19/19 [00:00<00:00, 205.20it/s]


train loss:  0.8449039459228516


100%|██████████| 19/19 [00:00<00:00, 220.72it/s]


train loss:  0.8205258846282959


100%|██████████| 19/19 [00:00<00:00, 225.31it/s]


train loss:  0.839770495891571


100%|██████████| 19/19 [00:00<00:00, 223.60it/s]


train loss:  0.8429250121116638


100%|██████████| 19/19 [00:00<00:00, 225.13it/s]


train loss:  0.8369588255882263


100%|██████████| 19/19 [00:00<00:00, 218.02it/s]


train loss:  0.8383029103279114


100%|██████████| 19/19 [00:00<00:00, 153.41it/s]


train loss:  0.8332257270812988


100%|██████████| 19/19 [00:00<00:00, 153.07it/s]


train loss:  0.8114460706710815


100%|██████████| 19/19 [00:00<00:00, 146.64it/s]


train loss:  0.8287160396575928


100%|██████████| 19/19 [00:00<00:00, 149.49it/s]


train loss:  0.8244792819023132


100%|██████████| 19/19 [00:00<00:00, 145.46it/s]


train loss:  0.8204066157341003


100%|██████████| 19/19 [00:00<00:00, 158.92it/s]


train loss:  0.8287147879600525


100%|██████████| 19/19 [00:00<00:00, 162.32it/s]


train loss:  0.83499675989151


100%|██████████| 19/19 [00:00<00:00, 154.37it/s]


train loss:  0.8082047700881958


100%|██████████| 19/19 [00:00<00:00, 155.16it/s]


train loss:  0.8020243048667908


100%|██████████| 19/19 [00:00<00:00, 161.29it/s]


train loss:  0.8145638704299927


100%|██████████| 19/19 [00:00<00:00, 156.92it/s]


train loss:  0.8090499043464661


100%|██████████| 19/19 [00:00<00:00, 156.56it/s]


train loss:  0.8130499720573425


100%|██████████| 19/19 [00:00<00:00, 140.39it/s]


train loss:  0.8130764961242676


100%|██████████| 19/19 [00:00<00:00, 147.02it/s]


train loss:  0.8095650672912598


100%|██████████| 19/19 [00:00<00:00, 154.02it/s]


train loss:  0.8125994801521301


100%|██████████| 19/19 [00:00<00:00, 136.75it/s]


train loss:  0.8080109357833862


100%|██████████| 19/19 [00:00<00:00, 134.77it/s]


train loss:  0.812744140625


100%|██████████| 19/19 [00:00<00:00, 139.30it/s]


train loss:  0.8043692708015442


100%|██████████| 19/19 [00:00<00:00, 144.73it/s]


train loss:  0.8212215304374695


100%|██████████| 19/19 [00:00<00:00, 142.58it/s]


train loss:  0.8397123217582703


100%|██████████| 19/19 [00:00<00:00, 140.82it/s]


train loss:  0.8031935691833496


100%|██████████| 19/19 [00:00<00:00, 135.41it/s]


train loss:  0.801294207572937


100%|██████████| 19/19 [00:00<00:00, 133.42it/s]


train loss:  0.8038684725761414


100%|██████████| 19/19 [00:00<00:00, 172.77it/s]


train loss:  0.7981265783309937


100%|██████████| 19/19 [00:00<00:00, 203.59it/s]


train loss:  0.8031008243560791


100%|██████████| 19/19 [00:00<00:00, 227.25it/s]


train loss:  0.8097288012504578


100%|██████████| 19/19 [00:00<00:00, 218.08it/s]


train loss:  0.8057976365089417


100%|██████████| 19/19 [00:00<00:00, 209.79it/s]


train loss:  0.8112236857414246


100%|██████████| 19/19 [00:00<00:00, 212.59it/s]


train loss:  0.8029036521911621


100%|██████████| 19/19 [00:00<00:00, 228.67it/s]


train loss:  0.7965664267539978


100%|██████████| 19/19 [00:00<00:00, 227.19it/s]


train loss:  0.80283123254776


100%|██████████| 19/19 [00:00<00:00, 210.60it/s]


train loss:  0.7830541133880615


100%|██████████| 19/19 [00:00<00:00, 224.84it/s]


train loss:  0.8074546456336975


100%|██████████| 19/19 [00:00<00:00, 215.52it/s]


train loss:  0.7978759407997131


100%|██████████| 19/19 [00:00<00:00, 228.04it/s]


train loss:  0.8009403347969055


100%|██████████| 19/19 [00:00<00:00, 206.05it/s]


train loss:  0.7904435992240906


100%|██████████| 19/19 [00:00<00:00, 221.87it/s]


train loss:  0.7996140122413635


100%|██████████| 19/19 [00:00<00:00, 211.06it/s]


train loss:  0.7837348580360413


100%|██████████| 19/19 [00:00<00:00, 198.91it/s]


train loss:  0.7798994779586792


100%|██████████| 19/19 [00:00<00:00, 220.00it/s]


train loss:  0.7905794382095337


100%|██████████| 19/19 [00:00<00:00, 222.10it/s]


train loss:  0.7847304940223694


100%|██████████| 19/19 [00:00<00:00, 222.21it/s]


train loss:  0.8051149249076843


100%|██████████| 19/19 [00:00<00:00, 212.19it/s]


train loss:  0.812161922454834


100%|██████████| 19/19 [00:00<00:00, 214.96it/s]


train loss:  0.7832977771759033


100%|██████████| 19/19 [00:00<00:00, 225.24it/s]


train loss:  0.7817939519882202


100%|██████████| 19/19 [00:00<00:00, 224.99it/s]


train loss:  0.782818615436554


100%|██████████| 19/19 [00:00<00:00, 217.04it/s]


train loss:  0.7678654193878174


100%|██████████| 19/19 [00:00<00:00, 223.02it/s]


train loss:  0.7783231139183044


100%|██████████| 19/19 [00:00<00:00, 222.29it/s]


train loss:  0.7755208015441895


100%|██████████| 19/19 [00:00<00:00, 203.65it/s]


train loss:  0.7837098836898804


100%|██████████| 19/19 [00:00<00:00, 221.62it/s]


train loss:  0.781227707862854


100%|██████████| 19/19 [00:00<00:00, 218.94it/s]


train loss:  0.7618780136108398


100%|██████████| 19/19 [00:00<00:00, 210.19it/s]


train loss:  0.7671809792518616


100%|██████████| 19/19 [00:00<00:00, 217.89it/s]


train loss:  0.7650306224822998


100%|██████████| 19/19 [00:00<00:00, 225.09it/s]


train loss:  0.7769263386726379


100%|██████████| 19/19 [00:00<00:00, 215.49it/s]


train loss:  0.7721641063690186


100%|██████████| 19/19 [00:00<00:00, 218.12it/s]


train loss:  0.7823691964149475


100%|██████████| 19/19 [00:00<00:00, 227.23it/s]


train loss:  0.7762802839279175


100%|██████████| 19/19 [00:00<00:00, 224.28it/s]


train loss:  0.7871953845024109


100%|██████████| 19/19 [00:00<00:00, 210.36it/s]


train loss:  0.7861400246620178


100%|██████████| 19/19 [00:00<00:00, 212.16it/s]


train loss:  0.7719875574111938


100%|██████████| 19/19 [00:00<00:00, 223.46it/s]


train loss:  0.7583948373794556


100%|██████████| 19/19 [00:00<00:00, 226.62it/s]


train loss:  0.7579503059387207


100%|██████████| 19/19 [00:00<00:00, 216.81it/s]


train loss:  0.763834536075592


100%|██████████| 19/19 [00:00<00:00, 209.67it/s]


train loss:  0.7586743831634521


100%|██████████| 19/19 [00:00<00:00, 223.38it/s]


train loss:  0.7737228274345398


100%|██████████| 19/19 [00:00<00:00, 218.90it/s]


train loss:  0.7513471841812134


100%|██████████| 19/19 [00:00<00:00, 226.45it/s]


train loss:  0.7546184062957764


100%|██████████| 19/19 [00:00<00:00, 178.02it/s]


train loss:  0.7751370072364807


100%|██████████| 19/19 [00:00<00:00, 217.07it/s]


train loss:  0.7666595578193665


100%|██████████| 19/19 [00:00<00:00, 202.95it/s]


train loss:  0.7511370778083801


100%|██████████| 19/19 [00:00<00:00, 222.05it/s]


train loss:  0.7682417631149292


100%|██████████| 19/19 [00:00<00:00, 221.98it/s]


train loss:  0.7557744979858398


100%|██████████| 19/19 [00:00<00:00, 217.50it/s]


train loss:  0.7458633184432983


100%|██████████| 19/19 [00:00<00:00, 214.98it/s]


train loss:  0.7558987736701965


100%|██████████| 19/19 [00:00<00:00, 222.48it/s]


train loss:  0.7503679990768433


100%|██████████| 19/19 [00:00<00:00, 206.68it/s]


train loss:  0.7535408139228821


100%|██████████| 19/19 [00:00<00:00, 226.69it/s]


train loss:  0.7577786445617676


100%|██████████| 19/19 [00:00<00:00, 217.53it/s]


train loss:  0.7617608308792114


100%|██████████| 19/19 [00:00<00:00, 197.59it/s]


train loss:  0.7407759428024292


100%|██████████| 19/19 [00:00<00:00, 166.82it/s]


train loss:  0.7569417357444763


100%|██████████| 19/19 [00:00<00:00, 225.45it/s]


train loss:  0.7582703828811646


100%|██████████| 19/19 [00:00<00:00, 231.06it/s]


train loss:  0.7480170726776123


100%|██████████| 19/19 [00:00<00:00, 179.81it/s]


train loss:  0.7477741241455078


100%|██████████| 19/19 [00:00<00:00, 214.75it/s]


train loss:  0.7466886043548584


100%|██████████| 19/19 [00:00<00:00, 210.19it/s]


train loss:  0.7635102272033691


100%|██████████| 19/19 [00:00<00:00, 217.03it/s]


train loss:  0.7301159501075745


100%|██████████| 19/19 [00:00<00:00, 224.18it/s]


train loss:  0.7474160194396973


100%|██████████| 19/19 [00:00<00:00, 212.64it/s]


train loss:  0.7360349893569946


100%|██████████| 19/19 [00:00<00:00, 211.63it/s]


train loss:  0.7444347143173218


100%|██████████| 19/19 [00:00<00:00, 223.29it/s]


train loss:  0.7442812919616699


100%|██████████| 19/19 [00:00<00:00, 199.59it/s]


train loss:  0.7457561492919922


100%|██████████| 19/19 [00:00<00:00, 216.91it/s]


train loss:  0.7413837313652039
Testing...


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
100%|██████████| 7/7 [00:00<00:00, 199.49it/s]


test loss:  1.362087368965149
F1 (Macro):  {'f1': 0.2817870712526472}






# train loss:  0.7413837313652039
# test loss:  1.362087368965149
# F1 (Macro):  {'f1': 0.2817870712526472}

# Generating text from a pre-trained decoder LM



## Step 1: Implement generate method
Implement the `generate` method. It takes a language model, its tokenizer, and a list of text prefixes, and outputs a list of generated sentences (one for each prefix). The prefixes should be "autocompleted" by the model.
Use greedy search and process all prefixes together as a single batch.


In [None]:
import torch
from typing import List

def generate(model, tokenizer, prefix: List[str], max_predicted: int) -> List[str]:
  """
  Autocomplete a batch of text prefixes with greedy decoding.

  All prefixes are processed together as one left-padded batch. After the
  first step, the model's key/value cache (``past_key_values``) is reused so
  each step only feeds the newly predicted token.

  :param model: PreTrainedModel (causal LM). Must expose ``.device`` and accept
      ``input_ids``, ``attention_mask``, ``position_ids`` and ``past_key_values``.
  :param tokenizer: PreTrainedTokenizer (https://huggingface.co/docs/transformers/main/main_classes/tokenizer#transformers.PreTrainedTokenizer).
      If it has no pad token, its EOS token is assigned as the pad token
      (this assignment persists on the tokenizer).
  :param prefix: Batch of prefixes to be autocompleted by the language model.
  :param max_predicted: Generation stops once the padded sequence length
      (longest tokenized prefix + generated tokens) reaches this value, or
      earlier when every sequence has produced EOS. At least one token is
      always generated.
  :return: One decoded string per prefix (prefix + completion) with special
      tokens removed.
  """
  if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
  pad_token_id = tokenizer.pad_token_id
  eos_token_id = tokenizer.eos_token_id
  device = model.device

  # 1. Tokenize the prefixes and prepare them for input into the language model.
  #    - add the tokenizer's special tokens (BOS if it defines one) but do not append EOS
  #    - pad on the LEFT, so the last column of every row holds a real token and the
  #      model continues right after each prefix
  #    - input_ids and attention_mask come back already stacked in the batch dim
  #    See: https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.__call__
  # Restore the caller's padding side afterwards so this call leaves no lasting change.
  previous_padding_side = tokenizer.padding_side
  tokenizer.padding_side = 'left'
  try:
    inputs = tokenizer(prefix, return_tensors='pt', padding=True, truncation=True, add_special_tokens=True)
  finally:
    tokenizer.padding_side = previous_padding_side
  input_ids = inputs['input_ids'].to(device)
  attention_mask = inputs['attention_mask'].to(device)

  # Position ids must count only real tokens: with left padding, the first real token of
  # every row has position 0. A plain arange would shift all padded (shorter) rows, which
  # for absolute-position models like GPT-2 degrades their completions badly.
  position_ids = attention_mask.long().cumsum(dim=-1) - 1
  # Padding positions are masked out by attention_mask, so any value works; 0 keeps them valid.
  position_ids.masked_fill_(attention_mask == 0, 0)

  batch_size = input_ids.shape[0]
  past_key_values = None
  unfinished_sequences = torch.ones(batch_size, dtype=torch.bool, device=device)
  predicted_input_ids = input_ids.clone()

  # Generation needs no gradients; without no_grad autograd would keep every step's graph alive.
  with torch.no_grad():
    while True:
      # 2. Predict the next-token logits. Already computed steps are provided through
      #    past_key_values; only the not-yet-seen tokens are passed as input_ids.
      outputs = model(
          input_ids=input_ids,
          attention_mask=attention_mask,
          position_ids=position_ids,
          past_key_values=past_key_values,
          return_dict=True
      )
      # (N x |V|)
      next_token_logits: torch.FloatTensor = outputs.logits[:, -1, :]

      # 3. Greedy search: pick the highest-scoring token. (N)
      next_tokens: torch.LongTensor = torch.argmax(next_token_logits, dim=-1)

      # 4. Sequences that already finished only receive <PAD>.
      next_tokens.masked_fill_(~unfinished_sequences, pad_token_id)

      # 5. Sequences that emitted EOS in this step are now finished. (N)
      newly_finished = next_tokens == eos_token_id
      unfinished_sequences &= ~newly_finished

      # 6. Append the new tokens and extend the attention mask by one column. (N x M_new)
      predicted_input_ids = torch.cat([predicted_input_ids, next_tokens.unsqueeze(-1)], dim=-1)
      attention_mask = torch.cat([attention_mask, attention_mask.new_ones((batch_size, 1))], dim=-1)
      # The last column is always a real (or just generated) token, so the next position is one past it.
      position_ids = position_ids[:, -1:] + 1
      input_ids = next_tokens.unsqueeze(-1)
      past_key_values = outputs.past_key_values

      # 7. Stop at the length limit or once every sequence has produced EOS.
      if predicted_input_ids.shape[1] >= max_predicted or not unfinished_sequences.any():
        break

  # 8. Convert back to text (padding / EOS are dropped as special tokens).
  generated_sentences: List[str] = tokenizer.batch_decode(predicted_input_ids, skip_special_tokens=True)
  return generated_sentences




## Step 2: Load a decoder language model

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

decoder_model_name = "gpt2-medium"

# Use the first GPU when available; fall back to CPU so the notebook still runs
# (more slowly) on machines without CUDA instead of crashing.
decoder_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

decoder_tokenizer = AutoTokenizer.from_pretrained(decoder_model_name)
decoder_model = AutoModelForCausalLM.from_pretrained(decoder_model_name)
decoder_model = decoder_model.to(device=decoder_device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/718 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.52G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

## Step 3: Apply the generate method to some example text

In [None]:
# Try your own prefixes here to see how the model completes them.
prefix_1 = "This sentences is about"
prefix_2 = "Can you complete this sentence?"
prefixes = [prefix_1, prefix_2]

generated_sentences = generate(decoder_model, decoder_tokenizer, prefixes, max_predicted=50)

# Print each prefix (0-based index) followed by its full completion.
for i, pair in enumerate(zip(prefixes, generated_sentences)):
  prefix, sent = pair
  header = f'Completed sentence {i}: "{prefix}"'
  print(header)
  print("-------------------------------------------------------\n")
  print(sent)
  print("\n=======================================================\n\n")

Completed sentence 0: "This sentences is about"
-------------------------------------------------------

This sentences is about about about about about about about about about about about about about about about about about about about about about about about about about about about about about about about about about about about about about about about about about about about about about



Completed sentence 1: "Can you complete this sentence?"
-------------------------------------------------------

Can you complete this sentence?

"I'm not sure if I can do it."

"I'm sorry, but I'm not sure if I can do it."

"I'm sorry, but I'm not sure if





In [None]:
# Try your own prefixes here to see how the model completes them.
prefix_1 = "The Parthenon is"
prefix_2 = "Parallel programming can be useful in cases where"
prefix_3 = "Freud is regarded as the father of psychoanalysis."
prefixes = [prefix_1, prefix_2, prefix_3]
# Observed in the outputs below: only the longest prefix (the one that needs no
# left padding) is completed coherently; the shorter ones repeat a single word.
# Likely cause (to confirm in `generate`): position ids that do not skip the left
# padding, so every shorter prefix starts at a shifted position. A single prefix
# is never padded, which is why it always works.

generated_sentences = generate(decoder_model, decoder_tokenizer, prefixes, max_predicted=50)

# Print each prefix (1-based index) followed by its full completion.
for i, pair in enumerate(zip(prefixes, generated_sentences)):
  prefix, sent = pair
  header = f'Completed sentence {i+1}: "{prefix}"'
  print(header)
  print("-------------------------------------------------------\n")
  print(sent)
  print("\n=======================================================\n\n")

Completed sentence 1: "The Parthenon is"
-------------------------------------------------------

The Parthenon is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is



Completed sentence 2: "Parallel programming can be useful in cases where"
-------------------------------------------------------

Parallel programming can be useful in cases where he he he he he he he he he he he he he he he he he he he he he he he he he he he he he he he he he he he he he he he he



Completed sentence 3: "Freud is regarded as the father of psychoanalysis"
-------------------------------------------------------

Freud is regarded as the father of psychoanalysis. He was born in 1883 in Vienna, Austria. He studied at the University of Vienna and then at the University of Berlin. He was a member of the German Academy of Sciences and the German





As you can see, the results tend to be quite repetitive.
This is why greedy search is typically not used in practice.
Common alternatives include beam search and sampling from the distribution of next tokens. While this is not part of the mandatory exercise, you can implement other generation methods as a bonus task and see how they improve the results.