In [15]:
from transformers import AutoConfig, AutoTokenizer, AutoModel, BertForSequenceClassification
from data.acronymDataset import AcronymDataset
from evaluate import load
import numpy as np
import torch

In [16]:
model_name = 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext'

# Choose the accelerator at runtime instead of hard-coding 'mps', so the
# notebook also runs on machines without the Apple Metal backend.
# MPS stays first in the preference order, so behaviour on Apple Silicon is unchanged.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# The checkpoint config supplies the hidden size (used for the heads) and the
# maximum number of positions (used as the tokenizer's length limit).
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=config.max_position_embeddings)

# Bare encoder without a task head; the pre-training head weights in the
# checkpoint are not loaded, which is what the warning below this cell reports.
pre_trained_model = AutoModel.from_pretrained(model_name).to(device)

Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Test the multi-head model with the acronym data

In [17]:
# Seed torch's global RNG before any data handling.
# NOTE(review): presumably this makes the train/validation split and shuffling
# repeatable; confirm that AcronymDataset draws its randomness from torch.
torch.manual_seed(5)

# Build the dataset from the raw text file, tokenizing with the model's tokenizer.
file_path = 'data/acronym_data.txt'
dataset = AcronymDataset(tokenizer=tokenizer, file_path=file_path)

# Keep a handle on the underlying data for ad-hoc inspection.
data = dataset.data

[INFO] Dataset already been loaded, using the cached dataset..


In [18]:
# Run the dataset's preprocessing step (called for its side effects; nothing is assigned).
# NOTE(review): "preprocss" looks like a misspelling of "preprocess" in the
# AcronymDataset API; if it is renamed there, this call must be updated too.
dataset.preprocss_dataset()

                                                                   

In [19]:
# Split the dataset into a training and a validation loader.
# NOTE(review): train_size=0.9 is presumably the fraction of examples used for
# training; confirm against AcronymDataset.get_dataloaders.
train_loader, val_loader = dataset.get_dataloaders(
    train_size=0.9,
    batch_size=32,
)

In [20]:
# Take a single training batch for smoke-testing and move it to the device the
# encoder lives on, rather than repeating a hard-coded device string, so the
# batch and the model can never end up on different devices.
batch = next(iter(train_loader)).to(pre_trained_model.device)

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [21]:
# %load_ext autoreload
# %autoreload 2
from models.multiHeadModel import MultiHeadModel
from models.heads import ClassificationHead

In [12]:
# Each head consumes the encoder's hidden representation, so its input width
# is the encoder's hidden size.
in_features = config.hidden_size

# Place the heads on the encoder's device (instead of a hard-coded 'mps') so
# that encoder outputs can be fed to them without a device mismatch.
two_labels_head = ClassificationHead(in_features=in_features, out_features=2).to(pre_trained_model.device)
four_labels_head = ClassificationHead(in_features=in_features, out_features=4).to(pre_trained_model.device)

# Heads are addressed by these keys when calling the multi-head model and when
# configuring training, so the names must stay in sync with those call sites.
classifiers = {
    "two_labels_head": two_labels_head,
    "four_labels_head": four_labels_head
}

In [13]:
# Combine the shared pre-trained encoder with the dictionary of task heads;
# a head is later selected by its key when the model is called.
multi_head_model = MultiHeadModel(
    pre_trained_model,
    classifiers,
)

In [None]:

# Smoke-test a forward pass through the two-label head on the sample batch.
# Gradient tracking is disabled because nothing is being trained here.
with torch.no_grad():
    output = multi_head_model(batch, "two_labels_head")

# Display the raw head output as the cell result.
output

In [41]:
# Accuracy of the two-label head on the single sample batch.
metric = load("accuracy")
labels = batch['labels']

# detach() keeps this cell working even if `output` was produced outside a
# no_grad block (Tensor.numpy() raises on tensors that require grad).
# The predicted class is the index of the largest score along the last axis.
predictions = np.argmax(output.detach().cpu().numpy(), axis=-1)

# Hand the metric host-side NumPy arrays for both arguments instead of mixing
# a NumPy array with a tensor that still lives on the accelerator.
res = metric.compute(predictions=predictions, references=labels.detach().cpu().numpy())

res

{'accuracy': 0.34375}

In [24]:
%load_ext autoreload
%autoreload 2

from utils.train import train
train_loader1, _ = dataset.get_dataloaders(train_size=0.9, batch_size=16)
train_loader2, _ = dataset.get_dataloaders(train_size=0.9, batch_size=32)

train_args = {
    "epochs": 1
}

heads_props = {
    "two_labels_head": {
        "train_loader": train_loader1
    },
    "four_labels_head": {
        "train_loader": train_loader2
    }
}

train(multi_head_model, heads_props, train_args)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([16, 334]) two_labels_head
torch.Size([32, 296]) four_labels_head
torch.Size([16, 278]) two_labels_head
torch.Size([32, 428]) four_labels_head
torch.Size([16, 308]) two_labels_head
torch.Size([32, 324]) four_labels_head
torch.Size([16, 266]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 310]) two_labels_head
torch.Size([32, 302]) four_labels_head
torch.Size([16, 329]) two_labels_head
torch.Size([32, 312]) four_labels_head
torch.Size([16, 325]) two_labels_head
torch.Size([32, 330]) four_labels_head
torch.Size([16, 283]) two_labels_head
torch.Size([32, 384]) four_labels_head
torch.Size([16, 262]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 257]) two_labels_head
torch.Size([32, 253]) four_labels_head
torch.Size([16, 261]) two_labels_head
torch.Size([32, 288]) four_labels_head




torch.Size([16, 411]) two_labels_head
torch.Size([32, 400]) four_labels_head
torch.Size([16, 269]) two_labels_head
torch.Size([32, 287]) four_labels_head




torch.Size([16, 284]) two_labels_head
torch.Size([32, 434]) four_labels_head
torch.Size([16, 405]) two_labels_head
torch.Size([32, 325]) four_labels_head
torch.Size([16, 348]) two_labels_head
torch.Size([32, 406]) four_labels_head
torch.Size([16, 390]) two_labels_head
torch.Size([32, 274]) four_labels_head
torch.Size([16, 325]) two_labels_head
torch.Size([32, 267]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 296]) four_labels_head
torch.Size([16, 293]) two_labels_head
torch.Size([32, 352]) four_labels_head
torch.Size([16, 494]) two_labels_head
torch.Size([32, 310]) four_labels_head
torch.Size([16, 289]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 275]) two_labels_head
torch.Size([32, 420]) four_labels_head
torch.Size([16, 355]) two_labels_head
torch.Size([32, 322]) four_labels_head
torch.Size([16, 378]) two_labels_head
torch.Size([32, 326]) four_labels_head
torch.Size([16, 312]) two_labels_head
torch.Size([32, 371]) four_labels_head



torch.Size([16, 332]) two_labels_head
torch.Size([32, 290]) four_labels_head
torch.Size([16, 291]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 310]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 341]) two_labels_head
torch.Size([32, 353]) four_labels_head
torch.Size([16, 367]) two_labels_head
torch.Size([32, 330]) four_labels_head
torch.Size([16, 417]) two_labels_head
torch.Size([32, 351]) four_labels_head
torch.Size([16, 262]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 365]) two_labels_head
torch.Size([32, 283]) four_labels_head
torch.Size([16, 377]) two_labels_head
torch.Size([32, 336]) four_labels_head
torch.Size([16, 308]) two_labels_head
torch.Size([32, 268]) four_labels_head
torch.Size([16, 322]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 343]) two_labels_head
torch.Size([32, 304]) four_labels_head



torch.Size([16, 284]) two_labels_head
torch.Size([32, 287]) four_labels_head
torch.Size([16, 327]) two_labels_head
torch.Size([32, 403]) four_labels_head
torch.Size([16, 418]) two_labels_head
torch.Size([32, 332]) four_labels_head
torch.Size([16, 271]) two_labels_head
torch.Size([32, 424]) four_labels_head
torch.Size([16, 322]) two_labels_head
torch.Size([32, 359]) four_labels_head
torch.Size([16, 326]) two_labels_head
torch.Size([32, 471]) four_labels_head
torch.Size([16, 290]) two_labels_head
torch.Size([32, 476]) four_labels_head
torch.Size([16, 260]) two_labels_head
torch.Size([32, 355]) four_labels_head
torch.Size([16, 336]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 480]) two_labels_head
torch.Size([32, 327]) four_labels_head
torch.Size([16, 288]) two_labels_head
torch.Size([32, 367]) four_labels_head
torch.Size([16, 362]) two_labels_head
torch.Size([32, 320]) four_labels_head
torch.Size([16, 241]) two_labels_head
torch.Size([32, 512]) four_labels_head



torch.Size([16, 309]) two_labels_head
torch.Size([32, 353]) four_labels_head
torch.Size([16, 314]) two_labels_head
torch.Size([32, 323]) four_labels_head
torch.Size([16, 407]) two_labels_head
torch.Size([32, 388]) four_labels_head
torch.Size([16, 278]) two_labels_head
torch.Size([32, 371]) four_labels_head
torch.Size([16, 306]) two_labels_head
torch.Size([32, 309]) four_labels_head
torch.Size([16, 343]) two_labels_head
torch.Size([32, 277]) four_labels_head
torch.Size([16, 274]) two_labels_head
torch.Size([32, 253]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 312]) four_labels_head
torch.Size([16, 301]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 288]) two_labels_head
torch.Size([32, 276]) four_labels_head
torch.Size([16, 279]) two_labels_head
torch.Size([32, 485]) four_labels_head
torch.Size([16, 289]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 295]) two_labels_head
torch.Size([32, 364]) four_labels_head



torch.Size([16, 304]) two_labels_head
torch.Size([32, 314]) four_labels_head
torch.Size([16, 255]) two_labels_head
torch.Size([32, 318]) four_labels_head
torch.Size([16, 313]) two_labels_head
torch.Size([32, 280]) four_labels_head
torch.Size([16, 303]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 379]) two_labels_head
torch.Size([32, 303]) four_labels_head
torch.Size([16, 274]) two_labels_head
torch.Size([32, 392]) four_labels_head
torch.Size([16, 318]) two_labels_head
torch.Size([32, 417]) four_labels_head
torch.Size([16, 304]) two_labels_head
torch.Size([32, 303]) four_labels_head
torch.Size([16, 365]) two_labels_head
torch.Size([32, 269]) four_labels_head
torch.Size([16, 259]) two_labels_head
torch.Size([32, 356]) four_labels_head
torch.Size([16, 379]) two_labels_head
torch.Size([32, 281]) four_labels_head
torch.Size([16, 262]) two_labels_head
torch.Size([32, 414]) four_labels_head
torch.Size([16, 307]) two_labels_head
torch.Size([32, 313]) four_labels_head



torch.Size([16, 273]) two_labels_head
torch.Size([32, 397]) four_labels_head
torch.Size([16, 279]) two_labels_head
torch.Size([32, 395]) four_labels_head
torch.Size([16, 327]) two_labels_head
torch.Size([32, 500]) four_labels_head
torch.Size([16, 292]) two_labels_head
torch.Size([32, 356]) four_labels_head
torch.Size([16, 288]) two_labels_head
torch.Size([32, 430]) four_labels_head
torch.Size([16, 428]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 299]) two_labels_head
torch.Size([32, 330]) four_labels_head
torch.Size([16, 403]) two_labels_head
torch.Size([32, 343]) four_labels_head
torch.Size([16, 309]) two_labels_head
torch.Size([32, 430]) four_labels_head
torch.Size([16, 273]) two_labels_head
torch.Size([32, 407]) four_labels_head
torch.Size([16, 266]) two_labels_head
torch.Size([32, 354]) four_labels_head
torch.Size([16, 292]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 250]) two_labels_head
torch.Size([32, 314]) four_labels_head



torch.Size([16, 322]) two_labels_head
torch.Size([32, 293]) four_labels_head
torch.Size([16, 271]) two_labels_head
torch.Size([32, 308]) four_labels_head
torch.Size([16, 283]) two_labels_head
torch.Size([32, 372]) four_labels_head
torch.Size([16, 243]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 382]) two_labels_head
torch.Size([32, 313]) four_labels_head
torch.Size([16, 308]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 282]) two_labels_head
torch.Size([32, 276]) four_labels_head
torch.Size([16, 281]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 445]) two_labels_head
torch.Size([32, 288]) four_labels_head
torch.Size([16, 298]) two_labels_head
torch.Size([32, 286]) four_labels_head
torch.Size([16, 290]) two_labels_head
torch.Size([32, 367]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 389]) four_labels_head
torch.Size([16, 315]) two_labels_head
torch.Size([32, 300]) four_labels_head



torch.Size([16, 342]) two_labels_head
torch.Size([32, 369]) four_labels_head
torch.Size([16, 302]) two_labels_head
torch.Size([32, 318]) four_labels_head
torch.Size([16, 348]) two_labels_head
torch.Size([32, 373]) four_labels_head
torch.Size([16, 344]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 357]) two_labels_head
torch.Size([32, 297]) four_labels_head
torch.Size([16, 402]) two_labels_head
torch.Size([32, 410]) four_labels_head
torch.Size([16, 248]) two_labels_head
torch.Size([32, 380]) four_labels_head
torch.Size([16, 319]) two_labels_head
torch.Size([32, 338]) four_labels_head
torch.Size([16, 271]) two_labels_head
torch.Size([32, 307]) four_labels_head
torch.Size([16, 385]) two_labels_head
torch.Size([32, 289]) four_labels_head
torch.Size([16, 240]) two_labels_head
torch.Size([32, 441]) four_labels_head
torch.Size([16, 391]) two_labels_head
torch.Size([32, 310]) four_labels_head
torch.Size([16, 431]) two_labels_head
torch.Size([32, 367]) four_labels_head



torch.Size([16, 315]) two_labels_head
torch.Size([32, 351]) four_labels_head
torch.Size([16, 242]) two_labels_head
torch.Size([32, 494]) four_labels_head
torch.Size([16, 354]) two_labels_head
torch.Size([32, 365]) four_labels_head
torch.Size([16, 337]) two_labels_head
torch.Size([32, 308]) four_labels_head
torch.Size([16, 334]) two_labels_head
torch.Size([32, 272]) four_labels_head
torch.Size([16, 285]) two_labels_head
torch.Size([32, 318]) four_labels_head
torch.Size([16, 264]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 267]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 267]) two_labels_head
torch.Size([32, 332]) four_labels_head
torch.Size([16, 342]) two_labels_head
torch.Size([32, 294]) four_labels_head
torch.Size([16, 287]) two_labels_head
torch.Size([32, 370]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 372]) four_labels_head
torch.Size([16, 307]) two_labels_head
torch.Size([32, 312]) four_labels_head



torch.Size([16, 394]) two_labels_head
torch.Size([32, 317]) four_labels_head
torch.Size([16, 432]) two_labels_head
torch.Size([32, 425]) four_labels_head
torch.Size([16, 373]) two_labels_head
torch.Size([32, 295]) four_labels_head
torch.Size([16, 325]) two_labels_head
torch.Size([32, 373]) four_labels_head
torch.Size([16, 278]) two_labels_head
torch.Size([32, 327]) four_labels_head
torch.Size([16, 304]) two_labels_head
torch.Size([32, 312]) four_labels_head
torch.Size([16, 348]) two_labels_head
torch.Size([32, 340]) four_labels_head
torch.Size([16, 300]) two_labels_head
torch.Size([32, 412]) four_labels_head
torch.Size([16, 324]) two_labels_head
torch.Size([32, 328]) four_labels_head
torch.Size([16, 438]) two_labels_head
torch.Size([32, 325]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 465]) four_labels_head
torch.Size([16, 274]) two_labels_head
torch.Size([32, 308]) four_labels_head
torch.Size([16, 302]) two_labels_head
torch.Size([32, 501]) four_labels_head



torch.Size([16, 282]) two_labels_head
torch.Size([32, 327]) four_labels_head
torch.Size([16, 351]) two_labels_head
torch.Size([32, 312]) four_labels_head
torch.Size([16, 396]) two_labels_head
torch.Size([32, 328]) four_labels_head
torch.Size([16, 388]) two_labels_head
torch.Size([32, 298]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 429]) four_labels_head
torch.Size([16, 303]) two_labels_head
torch.Size([32, 408]) four_labels_head
torch.Size([16, 353]) two_labels_head
torch.Size([32, 318]) four_labels_head
torch.Size([16, 301]) two_labels_head
torch.Size([32, 329]) four_labels_head
torch.Size([16, 267]) two_labels_head
torch.Size([32, 306]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 400]) four_labels_head
torch.Size([16, 274]) two_labels_head
torch.Size([32, 490]) four_labels_head
torch.Size([16, 303]) two_labels_head
torch.Size([32, 347]) four_labels_head
torch.Size([16, 344]) two_labels_head
torch.Size([32, 328]) four_labels_head



torch.Size([16, 512]) two_labels_head
torch.Size([32, 335]) four_labels_head
torch.Size([16, 264]) two_labels_head
torch.Size([32, 339]) four_labels_head
torch.Size([16, 399]) two_labels_head
torch.Size([32, 379]) four_labels_head
torch.Size([16, 321]) two_labels_head
torch.Size([32, 406]) four_labels_head
torch.Size([16, 325]) two_labels_head
torch.Size([32, 414]) four_labels_head
torch.Size([16, 378]) two_labels_head
torch.Size([32, 415]) four_labels_head
torch.Size([16, 364]) two_labels_head
torch.Size([32, 355]) four_labels_head
torch.Size([16, 327]) two_labels_head
torch.Size([32, 312]) four_labels_head
torch.Size([16, 263]) two_labels_head
torch.Size([32, 296]) four_labels_head
torch.Size([16, 379]) two_labels_head
torch.Size([32, 350]) four_labels_head
torch.Size([16, 275]) two_labels_head
torch.Size([32, 306]) four_labels_head
torch.Size([16, 332]) two_labels_head
torch.Size([32, 307]) four_labels_head




torch.Size([16, 369]) two_labels_head
torch.Size([32, 301]) four_labels_head
torch.Size([16, 321]) two_labels_head
torch.Size([32, 326]) four_labels_head
torch.Size([16, 249]) two_labels_head
torch.Size([32, 388]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 314]) four_labels_head
torch.Size([16, 297]) two_labels_head
torch.Size([32, 333]) four_labels_head
torch.Size([16, 307]) two_labels_head
torch.Size([32, 396]) four_labels_head
torch.Size([16, 384]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 304]) two_labels_head
torch.Size([32, 311]) four_labels_head
torch.Size([16, 379]) two_labels_head
torch.Size([32, 286]) four_labels_head
torch.Size([16, 342]) two_labels_head
torch.Size([32, 304]) four_labels_head
torch.Size([16, 277]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 427]) two_labels_head
torch.Size([32, 342]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 306]) four_labels_head



torch.Size([16, 287]) two_labels_head
torch.Size([32, 427]) four_labels_head
torch.Size([16, 281]) two_labels_head
torch.Size([32, 297]) four_labels_head
torch.Size([16, 262]) two_labels_head
torch.Size([32, 271]) four_labels_head
torch.Size([16, 264]) two_labels_head
torch.Size([32, 293]) four_labels_head
torch.Size([16, 380]) two_labels_head
torch.Size([32, 471]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 343]) four_labels_head
torch.Size([16, 295]) two_labels_head
torch.Size([32, 349]) four_labels_head
torch.Size([16, 258]) two_labels_head
torch.Size([32, 271]) four_labels_head
torch.Size([16, 425]) two_labels_head
torch.Size([32, 358]) four_labels_head
torch.Size([16, 418]) two_labels_head
torch.Size([32, 334]) four_labels_head
torch.Size([16, 293]) two_labels_head
torch.Size([32, 302]) four_labels_head
torch.Size([16, 307]) two_labels_head
torch.Size([32, 352]) four_labels_head
torch.Size([16, 262]) two_labels_head
torch.Size([32, 352]) four_labels_head



torch.Size([16, 264]) two_labels_head
torch.Size([32, 336]) four_labels_head
torch.Size([16, 274]) two_labels_head
torch.Size([32, 434]) four_labels_head
torch.Size([16, 396]) two_labels_head
torch.Size([32, 457]) four_labels_head
torch.Size([16, 336]) two_labels_head
torch.Size([32, 333]) four_labels_head
torch.Size([16, 298]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 331]) two_labels_head
torch.Size([32, 283]) four_labels_head
torch.Size([16, 384]) two_labels_head
torch.Size([32, 298]) four_labels_head
torch.Size([16, 230]) two_labels_head
torch.Size([32, 479]) four_labels_head
torch.Size([16, 455]) two_labels_head
torch.Size([32, 368]) four_labels_head
torch.Size([16, 416]) two_labels_head
torch.Size([32, 348]) four_labels_head
torch.Size([16, 356]) two_labels_head
torch.Size([32, 438]) four_labels_head
torch.Size([16, 254]) two_labels_head
torch.Size([32, 495]) four_labels_head
torch.Size([16, 448]) two_labels_head
torch.Size([32, 503]) four_labels_head



torch.Size([16, 286]) two_labels_head
torch.Size([32, 370]) four_labels_head
torch.Size([16, 301]) two_labels_head
torch.Size([32, 338]) four_labels_head
torch.Size([16, 288]) two_labels_head
torch.Size([32, 310]) four_labels_head
torch.Size([16, 299]) two_labels_head
torch.Size([32, 314]) four_labels_head
torch.Size([16, 411]) two_labels_head
torch.Size([32, 396]) four_labels_head
torch.Size([16, 275]) two_labels_head
torch.Size([32, 342]) four_labels_head
torch.Size([16, 328]) two_labels_head
torch.Size([32, 433]) four_labels_head
torch.Size([16, 351]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 389]) two_labels_head
torch.Size([32, 363]) four_labels_head
torch.Size([16, 398]) two_labels_head
torch.Size([32, 346]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 302]) four_labels_head
torch.Size([16, 249]) two_labels_head
torch.Size([32, 250]) four_labels_head
torch.Size([16, 267]) two_labels_head
torch.Size([32, 273]) four_labels_head



torch.Size([16, 280]) two_labels_head
torch.Size([32, 307]) four_labels_head
torch.Size([16, 243]) two_labels_head
torch.Size([32, 343]) four_labels_head
torch.Size([16, 321]) two_labels_head
torch.Size([32, 339]) four_labels_head
torch.Size([16, 288]) two_labels_head
torch.Size([32, 322]) four_labels_head
torch.Size([16, 268]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 286]) two_labels_head
torch.Size([32, 422]) four_labels_head
torch.Size([16, 261]) two_labels_head
torch.Size([32, 293]) four_labels_head
torch.Size([16, 358]) two_labels_head
torch.Size([32, 334]) four_labels_head
torch.Size([16, 312]) two_labels_head
torch.Size([32, 362]) four_labels_head
torch.Size([16, 367]) two_labels_head
torch.Size([32, 304]) four_labels_head
torch.Size([16, 308]) two_labels_head
torch.Size([32, 296]) four_labels_head
torch.Size([16, 231]) two_labels_head
torch.Size([32, 271]) four_labels_head
torch.Size([16, 315]) two_labels_head
torch.Size([32, 318]) four_labels_head



torch.Size([16, 378]) two_labels_head
torch.Size([32, 353]) four_labels_head
torch.Size([16, 408]) two_labels_head
torch.Size([32, 323]) four_labels_head
torch.Size([16, 301]) two_labels_head
torch.Size([32, 306]) four_labels_head
torch.Size([16, 314]) two_labels_head
torch.Size([32, 417]) four_labels_head
torch.Size([16, 292]) two_labels_head
torch.Size([32, 374]) four_labels_head
torch.Size([16, 278]) two_labels_head
torch.Size([32, 381]) four_labels_head
torch.Size([16, 340]) two_labels_head
torch.Size([32, 312]) four_labels_head
torch.Size([16, 286]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 272]) two_labels_head
torch.Size([32, 336]) four_labels_head
torch.Size([16, 316]) two_labels_head
torch.Size([32, 263]) four_labels_head
torch.Size([16, 294]) two_labels_head
torch.Size([32, 292]) four_labels_head
torch.Size([16, 307]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 303]) two_labels_head
torch.Size([32, 327]) four_labels_head



torch.Size([16, 255]) two_labels_head
torch.Size([32, 303]) four_labels_head
torch.Size([16, 244]) two_labels_head
torch.Size([32, 283]) four_labels_head
torch.Size([16, 338]) two_labels_head
torch.Size([32, 411]) four_labels_head
torch.Size([16, 264]) two_labels_head
torch.Size([32, 304]) four_labels_head
torch.Size([16, 254]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 378]) two_labels_head
torch.Size([32, 270]) four_labels_head
torch.Size([16, 284]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 217]) two_labels_head
torch.Size([32, 427]) four_labels_head
torch.Size([16, 228]) two_labels_head
torch.Size([32, 344]) four_labels_head
torch.Size([16, 457]) two_labels_head
torch.Size([32, 306]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 270]) four_labels_head
torch.Size([16, 284]) two_labels_head
torch.Size([32, 315]) four_labels_head
torch.Size([16, 250]) two_labels_head
torch.Size([32, 325]) four_labels_head



torch.Size([16, 477]) two_labels_head
torch.Size([32, 316]) four_labels_head
torch.Size([16, 296]) two_labels_head
torch.Size([32, 373]) four_labels_head
torch.Size([16, 361]) two_labels_head
torch.Size([32, 295]) four_labels_head
torch.Size([16, 283]) two_labels_head
torch.Size([32, 285]) four_labels_head
torch.Size([16, 252]) two_labels_head
torch.Size([32, 374]) four_labels_head
torch.Size([16, 301]) two_labels_head
torch.Size([32, 439]) four_labels_head
torch.Size([16, 324]) two_labels_head
torch.Size([32, 340]) four_labels_head
torch.Size([16, 264]) two_labels_head
torch.Size([32, 292]) four_labels_head
torch.Size([16, 308]) two_labels_head
torch.Size([32, 314]) four_labels_head
torch.Size([16, 417]) two_labels_head
torch.Size([32, 298]) four_labels_head
torch.Size([16, 376]) two_labels_head
torch.Size([32, 320]) four_labels_head
torch.Size([16, 281]) two_labels_head
torch.Size([32, 370]) four_labels_head
torch.Size([16, 279]) two_labels_head
torch.Size([32, 394]) four_labels_head



torch.Size([16, 265]) two_labels_head
torch.Size([32, 340]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 286]) four_labels_head
torch.Size([16, 336]) two_labels_head
torch.Size([32, 303]) four_labels_head
torch.Size([16, 294]) two_labels_head
torch.Size([32, 296]) four_labels_head
torch.Size([16, 296]) two_labels_head
torch.Size([32, 293]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 382]) four_labels_head
torch.Size([16, 324]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 241]) two_labels_head
torch.Size([32, 286]) four_labels_head
torch.Size([16, 286]) two_labels_head
torch.Size([32, 291]) four_labels_head
torch.Size([16, 355]) two_labels_head
torch.Size([32, 293]) four_labels_head
torch.Size([16, 249]) two_labels_head
torch.Size([32, 320]) four_labels_head
torch.Size([16, 394]) two_labels_head
torch.Size([32, 327]) four_labels_head
torch.Size([16, 275]) two_labels_head
torch.Size([32, 326]) four_labels_head



torch.Size([16, 371]) two_labels_head
torch.Size([32, 339]) four_labels_head
torch.Size([16, 264]) two_labels_head
torch.Size([32, 433]) four_labels_head
torch.Size([16, 294]) two_labels_head
torch.Size([32, 266]) four_labels_head
torch.Size([16, 376]) two_labels_head
torch.Size([32, 354]) four_labels_head
torch.Size([16, 223]) two_labels_head
torch.Size([32, 255]) four_labels_head
torch.Size([16, 338]) two_labels_head
torch.Size([32, 293]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 338]) four_labels_head
torch.Size([16, 304]) two_labels_head
torch.Size([32, 418]) four_labels_head
torch.Size([16, 275]) two_labels_head
torch.Size([32, 315]) four_labels_head
torch.Size([16, 367]) two_labels_head
torch.Size([32, 393]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 343]) four_labels_head
torch.Size([16, 248]) two_labels_head
torch.Size([32, 278]) four_labels_head
torch.Size([16, 268]) two_labels_head
torch.Size([32, 291]) four_labels_head



torch.Size([16, 340]) two_labels_head
torch.Size([32, 410]) four_labels_head
torch.Size([16, 258]) two_labels_head
torch.Size([32, 276]) four_labels_head
torch.Size([16, 443]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 305]) two_labels_head
torch.Size([32, 498]) four_labels_head
torch.Size([16, 264]) two_labels_head
torch.Size([32, 475]) four_labels_head
torch.Size([16, 269]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 307]) two_labels_head
torch.Size([32, 344]) four_labels_head
torch.Size([16, 287]) two_labels_head
torch.Size([32, 349]) four_labels_head
torch.Size([16, 401]) two_labels_head
torch.Size([32, 380]) four_labels_head
torch.Size([16, 335]) two_labels_head
torch.Size([32, 338]) four_labels_head
torch.Size([16, 451]) two_labels_head
torch.Size([32, 325]) four_labels_head
torch.Size([16, 428]) two_labels_head
torch.Size([32, 324]) four_labels_head
torch.Size([16, 366]) two_labels_head
torch.Size([32, 512]) four_labels_head



torch.Size([16, 512]) two_labels_head
torch.Size([32, 380]) four_labels_head
torch.Size([16, 306]) two_labels_head
torch.Size([32, 460]) four_labels_head
torch.Size([16, 271]) two_labels_head
torch.Size([32, 340]) four_labels_head
torch.Size([16, 234]) two_labels_head
torch.Size([32, 361]) four_labels_head
torch.Size([16, 284]) two_labels_head
torch.Size([32, 386]) four_labels_head
torch.Size([16, 260]) two_labels_head
torch.Size([32, 317]) four_labels_head
torch.Size([16, 300]) two_labels_head
torch.Size([32, 332]) four_labels_head
torch.Size([16, 400]) two_labels_head
torch.Size([32, 420]) four_labels_head
torch.Size([16, 353]) two_labels_head
torch.Size([32, 295]) four_labels_head
torch.Size([16, 286]) two_labels_head
torch.Size([32, 354]) four_labels_head
torch.Size([16, 266]) two_labels_head
torch.Size([32, 441]) four_labels_head
torch.Size([16, 335]) two_labels_head
torch.Size([32, 302]) four_labels_head
torch.Size([16, 290]) two_labels_head
torch.Size([32, 293]) four_labels_head



torch.Size([16, 337]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 336]) two_labels_head
torch.Size([32, 373]) four_labels_head
torch.Size([16, 269]) two_labels_head
torch.Size([32, 378]) four_labels_head
torch.Size([16, 282]) two_labels_head
torch.Size([32, 322]) four_labels_head
torch.Size([16, 252]) two_labels_head
torch.Size([32, 338]) four_labels_head
torch.Size([16, 259]) two_labels_head
torch.Size([32, 464]) four_labels_head
torch.Size([16, 423]) two_labels_head
torch.Size([32, 308]) four_labels_head
torch.Size([16, 264]) two_labels_head
torch.Size([32, 314]) four_labels_head
torch.Size([16, 261]) two_labels_head
torch.Size([32, 367]) four_labels_head
torch.Size([16, 298]) two_labels_head
torch.Size([32, 348]) four_labels_head
torch.Size([16, 259]) two_labels_head
torch.Size([32, 347]) four_labels_head
torch.Size([16, 296]) two_labels_head
torch.Size([32, 291]) four_labels_head
torch.Size([16, 215]) two_labels_head
torch.Size([32, 311]) four_labels_head



torch.Size([16, 356]) two_labels_head
torch.Size([32, 282]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 315]) four_labels_head
torch.Size([16, 282]) two_labels_head
torch.Size([32, 381]) four_labels_head
torch.Size([16, 311]) two_labels_head
torch.Size([32, 276]) four_labels_head
torch.Size([16, 281]) two_labels_head
torch.Size([32, 351]) four_labels_head
torch.Size([16, 280]) two_labels_head
torch.Size([32, 347]) four_labels_head
torch.Size([16, 306]) two_labels_head
torch.Size([32, 288]) four_labels_head
torch.Size([16, 342]) two_labels_head
torch.Size([32, 294]) four_labels_head
torch.Size([16, 269]) two_labels_head
torch.Size([32, 385]) four_labels_head
torch.Size([16, 302]) two_labels_head
torch.Size([32, 339]) four_labels_head
torch.Size([16, 428]) two_labels_head
torch.Size([32, 310]) four_labels_head
torch.Size([16, 461]) two_labels_head
torch.Size([32, 394]) four_labels_head
torch.Size([16, 258]) two_labels_head
torch.Size([32, 363]) four_labels_head



torch.Size([16, 395]) two_labels_head
torch.Size([32, 290]) four_labels_head
torch.Size([16, 449]) two_labels_head
torch.Size([32, 335]) four_labels_head
torch.Size([16, 273]) two_labels_head
torch.Size([32, 445]) four_labels_head
torch.Size([16, 358]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 395]) two_labels_head
torch.Size([32, 308]) four_labels_head
torch.Size([16, 460]) two_labels_head
torch.Size([32, 281]) four_labels_head
torch.Size([16, 271]) two_labels_head
torch.Size([32, 395]) four_labels_head
torch.Size([16, 433]) two_labels_head
torch.Size([32, 333]) four_labels_head
torch.Size([16, 260]) two_labels_head
torch.Size([32, 350]) four_labels_head
torch.Size([16, 328]) two_labels_head
torch.Size([32, 375]) four_labels_head
torch.Size([16, 349]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 464]) two_labels_head
torch.Size([32, 342]) four_labels_head
torch.Size([16, 303]) two_labels_head
torch.Size([32, 321]) four_labels_head



torch.Size([16, 384]) two_labels_head
torch.Size([32, 269]) four_labels_head
torch.Size([16, 356]) two_labels_head
torch.Size([32, 415]) four_labels_head
torch.Size([16, 261]) two_labels_head
torch.Size([32, 286]) four_labels_head
torch.Size([16, 354]) two_labels_head
torch.Size([32, 501]) four_labels_head
torch.Size([16, 290]) two_labels_head
torch.Size([32, 427]) four_labels_head
torch.Size([16, 278]) two_labels_head
torch.Size([32, 291]) four_labels_head
torch.Size([16, 328]) two_labels_head
torch.Size([32, 335]) four_labels_head
torch.Size([16, 285]) two_labels_head
torch.Size([32, 321]) four_labels_head
torch.Size([16, 239]) two_labels_head
torch.Size([32, 309]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 298]) four_labels_head
torch.Size([16, 241]) two_labels_head
torch.Size([32, 327]) four_labels_head
torch.Size([16, 278]) two_labels_head
torch.Size([32, 356]) four_labels_head
torch.Size([16, 292]) two_labels_head
torch.Size([32, 274]) four_labels_head



torch.Size([16, 359]) two_labels_head
torch.Size([32, 304]) four_labels_head
torch.Size([16, 274]) two_labels_head
torch.Size([32, 362]) four_labels_head
torch.Size([16, 340]) two_labels_head
torch.Size([32, 386]) four_labels_head
torch.Size([16, 379]) two_labels_head
torch.Size([32, 304]) four_labels_head
torch.Size([16, 306]) two_labels_head
torch.Size([32, 354]) four_labels_head
torch.Size([16, 270]) two_labels_head
torch.Size([32, 329]) four_labels_head
torch.Size([16, 365]) two_labels_head
torch.Size([32, 348]) four_labels_head
torch.Size([16, 343]) two_labels_head
torch.Size([32, 350]) four_labels_head
torch.Size([16, 294]) two_labels_head
torch.Size([32, 314]) four_labels_head
torch.Size([16, 333]) two_labels_head
torch.Size([32, 287]) four_labels_head
torch.Size([16, 310]) two_labels_head
torch.Size([32, 286]) four_labels_head
torch.Size([16, 500]) two_labels_head
torch.Size([32, 376]) four_labels_head
torch.Size([16, 372]) two_labels_head
torch.Size([32, 291]) four_labels_head



torch.Size([16, 307]) two_labels_head
torch.Size([32, 324]) four_labels_head
torch.Size([16, 275]) two_labels_head
torch.Size([32, 318]) four_labels_head
torch.Size([16, 264]) two_labels_head
torch.Size([32, 333]) four_labels_head
torch.Size([16, 273]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 241]) two_labels_head
torch.Size([32, 342]) four_labels_head
torch.Size([16, 293]) two_labels_head
torch.Size([32, 243]) four_labels_head
torch.Size([16, 271]) two_labels_head
torch.Size([32, 306]) four_labels_head
torch.Size([16, 271]) two_labels_head
torch.Size([32, 340]) four_labels_head
torch.Size([16, 342]) two_labels_head
torch.Size([32, 345]) four_labels_head
torch.Size([16, 333]) two_labels_head
torch.Size([32, 340]) four_labels_head
torch.Size([16, 401]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 373]) two_labels_head
torch.Size([32, 416]) four_labels_head
torch.Size([16, 330]) two_labels_head
torch.Size([32, 410]) four_labels_head



torch.Size([16, 367]) two_labels_head
torch.Size([32, 422]) four_labels_head
torch.Size([16, 356]) two_labels_head
torch.Size([32, 484]) four_labels_head
torch.Size([16, 346]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 326]) two_labels_head
torch.Size([32, 311]) four_labels_head
torch.Size([16, 292]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 308]) two_labels_head
torch.Size([32, 428]) four_labels_head
torch.Size([16, 231]) two_labels_head
torch.Size([32, 463]) four_labels_head
torch.Size([16, 286]) two_labels_head
torch.Size([32, 346]) four_labels_head
torch.Size([16, 269]) two_labels_head
torch.Size([32, 356]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 386]) four_labels_head
torch.Size([16, 351]) two_labels_head
torch.Size([32, 272]) four_labels_head
torch.Size([16, 401]) two_labels_head
torch.Size([32, 347]) four_labels_head
torch.Size([16, 303]) two_labels_head
torch.Size([32, 347]) four_labels_head



torch.Size([16, 285]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 344]) two_labels_head
torch.Size([32, 453]) four_labels_head
torch.Size([16, 302]) two_labels_head
torch.Size([32, 395]) four_labels_head
torch.Size([16, 392]) two_labels_head
torch.Size([32, 258]) four_labels_head
torch.Size([16, 280]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 296]) two_labels_head
torch.Size([32, 377]) four_labels_head
torch.Size([16, 285]) two_labels_head
torch.Size([32, 411]) four_labels_head
torch.Size([16, 391]) two_labels_head
torch.Size([32, 291]) four_labels_head
torch.Size([16, 253]) two_labels_head
torch.Size([32, 339]) four_labels_head
torch.Size([16, 347]) two_labels_head
torch.Size([32, 372]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 327]) four_labels_head
torch.Size([16, 260]) two_labels_head
torch.Size([32, 411]) four_labels_head
torch.Size([16, 455]) two_labels_head
torch.Size([32, 410]) four_labels_head




torch.Size([16, 280]) two_labels_head
torch.Size([32, 338]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 396]) four_labels_head
torch.Size([16, 321]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 284]) two_labels_head
torch.Size([32, 314]) four_labels_head
torch.Size([16, 264]) two_labels_head
torch.Size([32, 311]) four_labels_head
torch.Size([16, 315]) two_labels_head
torch.Size([32, 287]) four_labels_head
torch.Size([16, 361]) two_labels_head
torch.Size([32, 315]) four_labels_head
torch.Size([16, 388]) two_labels_head
torch.Size([32, 402]) four_labels_head
torch.Size([16, 231]) two_labels_head
torch.Size([32, 382]) four_labels_head
torch.Size([16, 436]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 371]) two_labels_head
torch.Size([32, 303]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 377]) four_labels_head


465it [00:07, 59.74it/s][A

torch.Size([16, 364]) two_labels_head
torch.Size([32, 283]) four_labels_head
torch.Size([16, 347]) two_labels_head
torch.Size([32, 320]) four_labels_head
torch.Size([16, 237]) two_labels_head
torch.Size([32, 403]) four_labels_head
torch.Size([16, 333]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 351]) two_labels_head
torch.Size([32, 355]) four_labels_head
torch.Size([16, 274]) two_labels_head
torch.Size([32, 312]) four_labels_head
torch.Size([16, 258]) two_labels_head
torch.Size([32, 374]) four_labels_head
torch.Size([16, 228]) two_labels_head
torch.Size([32, 394]) four_labels_head
torch.Size([16, 317]) two_labels_head
torch.Size([32, 300]) four_labels_head
torch.Size([16, 320]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 234]) two_labels_head
torch.Size([32, 330]) four_labels_head
torch.Size([16, 321]) two_labels_head
torch.Size([32, 361]) four_labels_head
torch.Size([16, 450]) two_labels_head
torch.Size([32, 294]) four_labels_head



torch.Size([16, 293]) two_labels_head
torch.Size([32, 439]) four_labels_head
torch.Size([16, 250]) two_labels_head
torch.Size([32, 278]) four_labels_head
torch.Size([16, 377]) two_labels_head
torch.Size([32, 283]) four_labels_head
torch.Size([16, 284]) two_labels_head
torch.Size([32, 266]) four_labels_head
torch.Size([16, 397]) two_labels_head
torch.Size([32, 288]) four_labels_head
torch.Size([16, 266]) two_labels_head
torch.Size([32, 273]) four_labels_head
torch.Size([16, 287]) two_labels_head
torch.Size([32, 351]) four_labels_head
torch.Size([16, 331]) two_labels_head
torch.Size([32, 384]) four_labels_head
torch.Size([16, 262]) two_labels_head
torch.Size([32, 301]) four_labels_head
torch.Size([16, 295]) two_labels_head
torch.Size([32, 326]) four_labels_head
torch.Size([16, 266]) two_labels_head
torch.Size([32, 452]) four_labels_head
torch.Size([16, 245]) two_labels_head
torch.Size([32, 327]) four_labels_head
torch.Size([16, 322]) two_labels_head
torch.Size([32, 244]) four_labels_head



torch.Size([16, 287]) two_labels_head
torch.Size([32, 345]) four_labels_head
torch.Size([16, 296]) two_labels_head
torch.Size([32, 405]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 365]) four_labels_head
torch.Size([16, 329]) two_labels_head
torch.Size([32, 352]) four_labels_head
torch.Size([16, 234]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 277]) two_labels_head
torch.Size([32, 343]) four_labels_head
torch.Size([16, 358]) two_labels_head
torch.Size([32, 301]) four_labels_head
torch.Size([16, 338]) two_labels_head
torch.Size([32, 355]) four_labels_head
torch.Size([16, 327]) two_labels_head
torch.Size([32, 335]) four_labels_head
torch.Size([16, 395]) two_labels_head
torch.Size([32, 353]) four_labels_head
torch.Size([16, 294]) two_labels_head
torch.Size([32, 299]) four_labels_head
torch.Size([16, 369]) two_labels_head
torch.Size([32, 303]) four_labels_head
torch.Size([16, 331]) two_labels_head
torch.Size([32, 372]) four_labels_head



torch.Size([16, 252]) two_labels_head
torch.Size([32, 469]) four_labels_head
torch.Size([16, 248]) two_labels_head
torch.Size([32, 376]) four_labels_head
torch.Size([16, 476]) two_labels_head
torch.Size([32, 323]) four_labels_head
torch.Size([16, 392]) two_labels_head
torch.Size([32, 348]) four_labels_head
torch.Size([16, 298]) two_labels_head
torch.Size([32, 431]) four_labels_head
torch.Size([16, 411]) two_labels_head
torch.Size([32, 323]) four_labels_head
torch.Size([16, 292]) two_labels_head
torch.Size([32, 307]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 391]) four_labels_head
torch.Size([16, 429]) two_labels_head
torch.Size([32, 315]) four_labels_head
torch.Size([16, 279]) two_labels_head
torch.Size([32, 371]) four_labels_head
torch.Size([16, 318]) two_labels_head
torch.Size([32, 385]) four_labels_head
torch.Size([16, 271]) two_labels_head
torch.Size([32, 242]) four_labels_head
torch.Size([16, 282]) two_labels_head
torch.Size([32, 271]) four_labels_head



torch.Size([16, 355]) two_labels_head
torch.Size([32, 322]) four_labels_head
torch.Size([16, 281]) two_labels_head
torch.Size([32, 448]) four_labels_head
torch.Size([16, 280]) two_labels_head
torch.Size([32, 288]) four_labels_head
torch.Size([16, 326]) two_labels_head
torch.Size([32, 357]) four_labels_head
torch.Size([16, 237]) two_labels_head
torch.Size([32, 264]) four_labels_head
torch.Size([16, 259]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 230]) two_labels_head
torch.Size([32, 291]) four_labels_head
torch.Size([16, 254]) two_labels_head
torch.Size([32, 285]) four_labels_head
torch.Size([16, 267]) two_labels_head
torch.Size([32, 309]) four_labels_head
torch.Size([16, 240]) two_labels_head
torch.Size([32, 317]) four_labels_head
torch.Size([16, 319]) two_labels_head
torch.Size([32, 403]) four_labels_head
torch.Size([16, 270]) two_labels_head
torch.Size([32, 334]) four_labels_head
torch.Size([16, 293]) two_labels_head
torch.Size([32, 512]) four_labels_head



torch.Size([16, 321]) two_labels_head
torch.Size([32, 290]) four_labels_head
torch.Size([16, 371]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 279]) two_labels_head
torch.Size([32, 440]) four_labels_head
torch.Size([16, 278]) two_labels_head
torch.Size([32, 270]) four_labels_head
torch.Size([16, 338]) two_labels_head
torch.Size([32, 332]) four_labels_head
torch.Size([16, 262]) two_labels_head
torch.Size([32, 311]) four_labels_head
torch.Size([16, 314]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 269]) two_labels_head
torch.Size([32, 388]) four_labels_head
torch.Size([16, 485]) two_labels_head
torch.Size([32, 305]) four_labels_head
torch.Size([16, 268]) two_labels_head
torch.Size([32, 271]) four_labels_head
torch.Size([16, 339]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 276]) two_labels_head
torch.Size([32, 354]) four_labels_head
torch.Size([16, 294]) two_labels_head
torch.Size([32, 512]) four_labels_head



torch.Size([16, 271]) two_labels_head
torch.Size([32, 351]) four_labels_head
torch.Size([16, 270]) two_labels_head
torch.Size([32, 283]) four_labels_head
torch.Size([16, 320]) two_labels_head
torch.Size([32, 427]) four_labels_head
torch.Size([16, 302]) two_labels_head
torch.Size([32, 343]) four_labels_head
torch.Size([16, 297]) two_labels_head
torch.Size([32, 328]) four_labels_head
torch.Size([16, 294]) two_labels_head
torch.Size([32, 366]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 371]) four_labels_head
torch.Size([16, 247]) two_labels_head
torch.Size([32, 289]) four_labels_head
torch.Size([16, 328]) two_labels_head
torch.Size([32, 388]) four_labels_head
torch.Size([16, 410]) two_labels_head
torch.Size([32, 290]) four_labels_head
torch.Size([16, 307]) two_labels_head
torch.Size([32, 432]) four_labels_head
torch.Size([16, 267]) two_labels_head
torch.Size([32, 318]) four_labels_head
torch.Size([16, 284]) two_labels_head
torch.Size([32, 417]) four_labels_head



torch.Size([16, 331]) two_labels_head
torch.Size([32, 405]) four_labels_head
torch.Size([16, 430]) two_labels_head
torch.Size([32, 320]) four_labels_head
torch.Size([16, 371]) two_labels_head
torch.Size([32, 338]) four_labels_head
torch.Size([16, 345]) two_labels_head
torch.Size([32, 434]) four_labels_head
torch.Size([16, 219]) two_labels_head
torch.Size([32, 259]) four_labels_head
torch.Size([16, 217]) two_labels_head
torch.Size([32, 341]) four_labels_head
torch.Size([16, 284]) two_labels_head
torch.Size([32, 301]) four_labels_head
torch.Size([16, 392]) two_labels_head
torch.Size([32, 376]) four_labels_head
torch.Size([16, 281]) two_labels_head
torch.Size([32, 413]) four_labels_head
torch.Size([16, 286]) two_labels_head
torch.Size([32, 306]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 277]) four_labels_head
torch.Size([16, 286]) two_labels_head
torch.Size([32, 293]) four_labels_head
torch.Size([16, 330]) two_labels_head
torch.Size([32, 349]) four_labels_head



torch.Size([16, 288]) two_labels_head
torch.Size([32, 287]) four_labels_head
torch.Size([16, 274]) two_labels_head
torch.Size([32, 399]) four_labels_head
torch.Size([16, 260]) two_labels_head
torch.Size([32, 355]) four_labels_head
torch.Size([16, 500]) two_labels_head
torch.Size([32, 358]) four_labels_head
torch.Size([16, 469]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 334]) two_labels_head
torch.Size([32, 409]) four_labels_head
torch.Size([16, 371]) two_labels_head
torch.Size([32, 284]) four_labels_head
torch.Size([16, 294]) two_labels_head
torch.Size([32, 460]) four_labels_head
torch.Size([16, 417]) two_labels_head
torch.Size([32, 384]) four_labels_head
torch.Size([16, 264]) two_labels_head
torch.Size([32, 400]) four_labels_head
torch.Size([16, 338]) two_labels_head
torch.Size([32, 292]) four_labels_head
torch.Size([16, 509]) two_labels_head
torch.Size([32, 397]) four_labels_head
torch.Size([16, 336]) two_labels_head
torch.Size([32, 325]) four_labels_head



torch.Size([16, 359]) two_labels_head
torch.Size([32, 334]) four_labels_head
torch.Size([16, 268]) two_labels_head
torch.Size([32, 455]) four_labels_head
torch.Size([16, 430]) two_labels_head
torch.Size([32, 335]) four_labels_head
torch.Size([16, 257]) two_labels_head
torch.Size([32, 301]) four_labels_head
torch.Size([16, 441]) two_labels_head
torch.Size([32, 373]) four_labels_head
torch.Size([16, 282]) two_labels_head
torch.Size([32, 380]) four_labels_head
torch.Size([16, 268]) two_labels_head
torch.Size([32, 429]) four_labels_head
torch.Size([16, 283]) two_labels_head
torch.Size([32, 302]) four_labels_head
torch.Size([16, 317]) two_labels_head
torch.Size([32, 276]) four_labels_head
torch.Size([16, 421]) two_labels_head
torch.Size([32, 455]) four_labels_head
torch.Size([16, 294]) two_labels_head
torch.Size([32, 375]) four_labels_head
torch.Size([16, 296]) two_labels_head
torch.Size([32, 329]) four_labels_head
torch.Size([16, 356]) two_labels_head
torch.Size([32, 284]) four_labels_head



torch.Size([16, 354]) two_labels_head
torch.Size([32, 309]) four_labels_head
torch.Size([16, 247]) two_labels_head
torch.Size([32, 317]) four_labels_head
torch.Size([16, 465]) two_labels_head
torch.Size([32, 377]) four_labels_head
torch.Size([16, 301]) two_labels_head
torch.Size([32, 334]) four_labels_head
torch.Size([16, 364]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 354]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 235]) two_labels_head
torch.Size([32, 284]) four_labels_head
torch.Size([16, 255]) two_labels_head
torch.Size([32, 290]) four_labels_head
torch.Size([16, 414]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 307]) two_labels_head
torch.Size([32, 443]) four_labels_head
torch.Size([16, 300]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 280]) two_labels_head
torch.Size([32, 327]) four_labels_head
torch.Size([16, 271]) two_labels_head
torch.Size([32, 512]) four_labels_head



torch.Size([16, 302]) two_labels_head
torch.Size([32, 291]) four_labels_head
torch.Size([16, 320]) two_labels_head
torch.Size([32, 346]) four_labels_head
torch.Size([16, 275]) two_labels_head
torch.Size([32, 375]) four_labels_head
torch.Size([16, 283]) two_labels_head
torch.Size([32, 332]) four_labels_head
torch.Size([16, 299]) two_labels_head
torch.Size([32, 450]) four_labels_head
torch.Size([16, 250]) two_labels_head
torch.Size([32, 305]) four_labels_head
torch.Size([16, 311]) two_labels_head
torch.Size([32, 380]) four_labels_head
torch.Size([16, 254]) two_labels_head
torch.Size([32, 294]) four_labels_head
torch.Size([16, 356]) two_labels_head
torch.Size([32, 336]) four_labels_head
torch.Size([16, 400]) two_labels_head
torch.Size([32, 265]) four_labels_head
torch.Size([16, 341]) two_labels_head
torch.Size([32, 404]) four_labels_head
torch.Size([16, 279]) two_labels_head
torch.Size([32, 355]) four_labels_head
torch.Size([16, 285]) two_labels_head
torch.Size([32, 323]) four_labels_head



torch.Size([16, 243]) two_labels_head
torch.Size([32, 360]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 313]) four_labels_head
torch.Size([16, 224]) two_labels_head
torch.Size([32, 313]) four_labels_head
torch.Size([16, 410]) two_labels_head
torch.Size([32, 461]) four_labels_head
torch.Size([16, 454]) two_labels_head
torch.Size([32, 302]) four_labels_head
torch.Size([16, 360]) two_labels_head
torch.Size([32, 296]) four_labels_head
torch.Size([16, 375]) two_labels_head
torch.Size([32, 257]) four_labels_head
torch.Size([16, 218]) two_labels_head
torch.Size([32, 502]) four_labels_head
torch.Size([16, 284]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 239]) two_labels_head
torch.Size([32, 430]) four_labels_head
torch.Size([16, 272]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 340]) two_labels_head
torch.Size([32, 369]) four_labels_head
torch.Size([16, 318]) two_labels_head
torch.Size([32, 302]) four_labels_head



torch.Size([16, 288]) two_labels_head
torch.Size([32, 378]) four_labels_head
torch.Size([16, 233]) two_labels_head
torch.Size([32, 351]) four_labels_head
torch.Size([16, 288]) two_labels_head
torch.Size([32, 453]) four_labels_head
torch.Size([16, 305]) two_labels_head
torch.Size([32, 329]) four_labels_head
torch.Size([16, 302]) two_labels_head
torch.Size([32, 311]) four_labels_head
torch.Size([16, 380]) two_labels_head
torch.Size([32, 378]) four_labels_head
torch.Size([16, 223]) two_labels_head
torch.Size([32, 329]) four_labels_head
torch.Size([16, 327]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 323]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 267]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 301]) two_labels_head
torch.Size([32, 298]) four_labels_head
torch.Size([16, 322]) two_labels_head
torch.Size([32, 449]) four_labels_head




torch.Size([16, 323]) two_labels_head
torch.Size([32, 308]) four_labels_head
torch.Size([16, 294]) two_labels_head
torch.Size([32, 342]) four_labels_head
torch.Size([16, 289]) two_labels_head
torch.Size([32, 275]) four_labels_head
torch.Size([16, 287]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 464]) two_labels_head
torch.Size([32, 341]) four_labels_head
torch.Size([16, 347]) two_labels_head
torch.Size([32, 372]) four_labels_head
torch.Size([16, 248]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 264]) two_labels_head
torch.Size([32, 292]) four_labels_head
torch.Size([16, 291]) two_labels_head
torch.Size([32, 306]) four_labels_head
torch.Size([16, 396]) two_labels_head
torch.Size([32, 302]) four_labels_head
torch.Size([16, 241]) two_labels_head
torch.Size([32, 403]) four_labels_head
torch.Size([16, 396]) two_labels_head
torch.Size([32, 364]) four_labels_head
torch.Size([16, 365]) two_labels_head
torch.Size([32, 282]) four_labels_head



torch.Size([16, 292]) two_labels_head
torch.Size([32, 368]) four_labels_head
torch.Size([16, 282]) two_labels_head
torch.Size([32, 341]) four_labels_head
torch.Size([16, 375]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 298]) four_labels_head
torch.Size([16, 326]) two_labels_head
torch.Size([32, 314]) four_labels_head
torch.Size([16, 373]) two_labels_head
torch.Size([32, 287]) four_labels_head
torch.Size([16, 248]) two_labels_head
torch.Size([32, 371]) four_labels_head
torch.Size([16, 285]) two_labels_head
torch.Size([32, 281]) four_labels_head
torch.Size([16, 384]) two_labels_head
torch.Size([32, 324]) four_labels_head
torch.Size([16, 253]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 270]) two_labels_head
torch.Size([32, 312]) four_labels_head
torch.Size([16, 360]) two_labels_head
torch.Size([32, 306]) four_labels_head
torch.Size([16, 269]) two_labels_head
torch.Size([32, 512]) four_labels_head



torch.Size([16, 277]) two_labels_head
torch.Size([32, 404]) four_labels_head
torch.Size([16, 413]) two_labels_head
torch.Size([32, 339]) four_labels_head
torch.Size([16, 361]) two_labels_head
torch.Size([32, 315]) four_labels_head
torch.Size([16, 403]) two_labels_head
torch.Size([32, 316]) four_labels_head
torch.Size([16, 256]) two_labels_head
torch.Size([32, 290]) four_labels_head
torch.Size([16, 216]) two_labels_head
torch.Size([32, 325]) four_labels_head
torch.Size([16, 327]) two_labels_head
torch.Size([32, 269]) four_labels_head
torch.Size([16, 251]) two_labels_head
torch.Size([32, 379]) four_labels_head
torch.Size([16, 243]) two_labels_head
torch.Size([32, 356]) four_labels_head
torch.Size([16, 298]) two_labels_head
torch.Size([32, 327]) four_labels_head
torch.Size([16, 376]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 253]) two_labels_head
torch.Size([32, 392]) four_labels_head
torch.Size([16, 325]) two_labels_head
torch.Size([32, 395]) four_labels_head



torch.Size([16, 287]) two_labels_head
torch.Size([32, 441]) four_labels_head
torch.Size([16, 235]) two_labels_head
torch.Size([32, 307]) four_labels_head
torch.Size([16, 356]) two_labels_head
torch.Size([32, 349]) four_labels_head
torch.Size([16, 308]) two_labels_head
torch.Size([32, 403]) four_labels_head
torch.Size([16, 292]) two_labels_head
torch.Size([32, 428]) four_labels_head
torch.Size([16, 298]) two_labels_head
torch.Size([32, 407]) four_labels_head
torch.Size([16, 292]) two_labels_head
torch.Size([32, 402]) four_labels_head
torch.Size([16, 280]) two_labels_head
torch.Size([32, 281]) four_labels_head
torch.Size([16, 304]) two_labels_head
torch.Size([32, 501]) four_labels_head
torch.Size([16, 298]) two_labels_head
torch.Size([32, 291]) four_labels_head
torch.Size([16, 461]) two_labels_head
torch.Size([32, 434]) four_labels_head
torch.Size([16, 367]) two_labels_head
torch.Size([32, 323]) four_labels_head
torch.Size([16, 282]) two_labels_head
torch.Size([32, 356]) four_labels_head



torch.Size([16, 269]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 301]) two_labels_head
torch.Size([32, 290]) four_labels_head
torch.Size([16, 283]) two_labels_head
torch.Size([32, 349]) four_labels_head
torch.Size([16, 295]) two_labels_head
torch.Size([32, 268]) four_labels_head
torch.Size([16, 309]) two_labels_head
torch.Size([32, 502]) four_labels_head
torch.Size([16, 286]) two_labels_head
torch.Size([32, 343]) four_labels_head
torch.Size([16, 329]) two_labels_head
torch.Size([32, 398]) four_labels_head
torch.Size([16, 298]) two_labels_head
torch.Size([32, 342]) four_labels_head
torch.Size([16, 422]) two_labels_head
torch.Size([32, 297]) four_labels_head
torch.Size([16, 285]) two_labels_head
torch.Size([32, 351]) four_labels_head
torch.Size([16, 357]) two_labels_head
torch.Size([32, 369]) four_labels_head
torch.Size([16, 502]) two_labels_head
torch.Size([32, 326]) four_labels_head
torch.Size([16, 255]) two_labels_head
torch.Size([32, 359]) four_labels_head



torch.Size([16, 303]) two_labels_head
torch.Size([32, 291]) four_labels_head
torch.Size([16, 278]) two_labels_head
torch.Size([32, 373]) four_labels_head
torch.Size([16, 263]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 301]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 303]) two_labels_head
torch.Size([32, 325]) four_labels_head
torch.Size([16, 279]) two_labels_head
torch.Size([32, 369]) four_labels_head
torch.Size([16, 256]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 309]) two_labels_head
torch.Size([32, 336]) four_labels_head
torch.Size([16, 308]) two_labels_head
torch.Size([32, 384]) four_labels_head
torch.Size([16, 254]) two_labels_head
torch.Size([32, 400]) four_labels_head
torch.Size([16, 312]) two_labels_head
torch.Size([32, 302]) four_labels_head
torch.Size([16, 246]) two_labels_head
torch.Size([32, 306]) four_labels_head
torch.Size([16, 337]) two_labels_head
torch.Size([32, 439]) four_labels_head



torch.Size([16, 315]) two_labels_head
torch.Size([32, 408]) four_labels_head
torch.Size([16, 299]) two_labels_head
torch.Size([32, 378]) four_labels_head
torch.Size([16, 341]) two_labels_head
torch.Size([32, 257]) four_labels_head
torch.Size([16, 277]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 344]) two_labels_head
torch.Size([32, 344]) four_labels_head
torch.Size([16, 301]) two_labels_head
torch.Size([32, 322]) four_labels_head
torch.Size([16, 263]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 344]) two_labels_head
torch.Size([32, 330]) four_labels_head
torch.Size([16, 291]) two_labels_head
torch.Size([32, 390]) four_labels_head
torch.Size([16, 280]) two_labels_head
torch.Size([32, 325]) four_labels_head
torch.Size([16, 275]) two_labels_head
torch.Size([32, 425]) four_labels_head
torch.Size([16, 278]) two_labels_head
torch.Size([32, 322]) four_labels_head
torch.Size([16, 277]) two_labels_head
torch.Size([32, 367]) four_labels_head



torch.Size([16, 346]) two_labels_head
torch.Size([32, 315]) four_labels_head
torch.Size([16, 354]) two_labels_head
torch.Size([32, 455]) four_labels_head
torch.Size([16, 295]) two_labels_head
torch.Size([32, 392]) four_labels_head
torch.Size([16, 329]) two_labels_head
torch.Size([32, 363]) four_labels_head
torch.Size([16, 297]) two_labels_head
torch.Size([32, 418]) four_labels_head
torch.Size([16, 264]) two_labels_head
torch.Size([32, 318]) four_labels_head
torch.Size([16, 320]) two_labels_head
torch.Size([32, 490]) four_labels_head
torch.Size([16, 228]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 292]) two_labels_head
torch.Size([32, 353]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 390]) four_labels_head
torch.Size([16, 344]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 512]) two_labels_head
torch.Size([32, 316]) four_labels_head
torch.Size([16, 328]) two_labels_head
torch.Size([32, 355]) four_labels_head



torch.Size([16, 305]) two_labels_head
torch.Size([32, 344]) four_labels_head
torch.Size([16, 326]) two_labels_head
torch.Size([32, 351]) four_labels_head
torch.Size([16, 330]) two_labels_head
torch.Size([32, 298]) four_labels_head
torch.Size([16, 283]) two_labels_head
torch.Size([32, 418]) four_labels_head
torch.Size([16, 295]) two_labels_head
torch.Size([32, 512]) four_labels_head
torch.Size([16, 388]) two_labels_head
torch.Size([32, 475]) four_labels_head
torch.Size([16, 360]) two_labels_head
torch.Size([32, 316]) four_labels_head
torch.Size([16, 389]) two_labels_head
torch.Size([32, 333]) four_labels_head
torch.Size([16, 251]) two_labels_head
torch.Size([32, 317]) four_labels_head
torch.Size([16, 262]) two_labels_head
torch.Size([32, 298]) four_labels_head
torch.Size([16, 377]) two_labels_head
torch.Size([32, 358]) four_labels_head
torch.Size([16, 350]) two_labels_head
torch.Size([32, 326]) four_labels_head
torch.Size([16, 296]) two_labels_head
torch.Size([32, 312]) four_labels_head

781it [00:11, 65.18it/s]
  0%|          | 0/1 [00:11<?, ?it/s]


torch.Size([16, 281]) two_labels_head
torch.Size([32, 287]) four_labels_head


KeyboardInterrupt: 