In [2]:
from datasets import load_dataset

# Dataset

In [3]:
dataset = load_dataset("yelp_review_full")

In [4]:
dataset["train"][100]

{'label': 0,
 'text': 'My expectations for McDonalds are t rarely high. But for one to still fail so spectacularly...that takes something special!\\nThe cashier took my friends\'s order, then promptly ignored me. I had to force myself in front of a cashier who opened his register to wait on the person BEHIND me. I waited over five minutes for a gigantic order that included precisely one kid\'s meal. After watching two people who ordered after me be handed their food, I asked where mine was. The manager started yelling at the cashiers for \\"serving off their orders\\" when they didn\'t have their food. But neither cashier was anywhere near those controls, and the manager was the one serving food to customers and clearing the boards.\\nThe manager was rude when giving me my order. She didn\'t make sure that I had everything ON MY RECEIPT, and never even had the decency to apologize that I felt I was getting poor service.\\nI\'ve eaten at various McDonalds restaurants for over 30 years. 

In [12]:
from transformers import AutoTokenizer

In [13]:
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")

In [14]:
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

In [15]:
tokenized_datasets = dataset.map(tokenize_function, batched=True)

In [16]:
from torch.utils.data import DataLoader

In [17]:
tokenized_datasets = tokenized_datasets.remove_columns(["text"])
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets = tokenized_datasets.remove_columns("token_type_ids")
tokenized_datasets.set_format("torch")

In [18]:
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))

In [19]:
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))

In [20]:
train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=8)

In [21]:
eval_dataloader = DataLoader(small_eval_dataset, batch_size=8)

In [22]:
from transformers import AutoModelForSequenceClassification

In [23]:
model = AutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased", num_labels=5)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [24]:
from torch.optim import AdamW

In [25]:
optimizer = AdamW(model.parameters(), lr=5e-5)

In [26]:
from transformers import get_scheduler

In [27]:
num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

In [28]:
import torch

In [29]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = AutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased", num_labels=5)
model.to(device)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

# MRL

In [30]:
from typing import List

import torch
import torch.nn as nn
from typing import Type, Any, Callable, Union, List, Optional
from transformers import AutoConfig, AutoModel

In [39]:
class Matryoshka_CE_Loss(nn.Module):
    def __init__(self, relative_importance: List[float]=None, **kwargs):
        super(Matryoshka_CE_Loss, self).__init__()
        self.criterion = nn.CrossEntropyLoss(**kwargs)

    def forward(self, output, target):
        # output shape: [G granularities, N batch size, C number of classes]
        # target shape: [N batch size]
        # Calculate losses for each output and stack them. This is still O(N)
        losses = torch.stack([self.criterion(output_i, target) for output_i in output])
        
        # Set relative_importance to 1 if not specified
        rel_importance = torch.ones_like(losses)
        
        # Apply relative importance weights
        weighted_losses = rel_importance * losses
        return weighted_losses.sum()

In [40]:
class MRL_Linear_Layer(nn.Module):
	def __init__(self, nesting_list: List, num_classes=5, efficient=False, **kwargs):
		super(MRL_Linear_Layer, self).__init__()
		self.nesting_list = nesting_list
		self.num_classes = num_classes # Number of classes for classification
		for i, num_feat in enumerate(self.nesting_list):
			setattr(self, f"nesting_classifier_{i}", nn.Linear(num_feat, self.num_classes, **kwargs))	

	def reset_parameters(self):
		for i in range(len(self.nesting_list)):
				getattr(self, f"nesting_classifier_{i}").reset_parameters()


	def forward(self, x):
		nesting_logits = ()
		for i, num_feat in enumerate(self.nesting_list):
			nesting_logits +=  (getattr(self, f"nesting_classifier_{i}")(x[:, :num_feat]),)

		return nesting_logits

In [41]:
class CustomModel(nn.Module):
  def __init__(self,num_labels): 
    super(CustomModel,self).__init__() 
    self.num_labels = num_labels

    #Load Model with given checkpoint and extract its body

    config = AutoConfig.from_pretrained('google-bert/bert-base-cased', output_attention=True, output_hidden_states=True)

    self.model = model = AutoModel.from_pretrained('google-bert/bert-base-cased', config=config)

    # self.model = model = AutoModel.from_pretrained(checkpoint,config=AutoConfig.from_pretrained(checkpoint, output_attentions=True,output_hidden_states=True))
    self.dropout = nn.Dropout(0.1) 
    nesting_list = [8, 16, 32, 64, 128, 256, 512, 768]
    self.classifier = MRL_Linear_Layer(nesting_list,num_labels) # load and initialize weights

  def forward(self, input_ids=None, attention_mask=None,labels=None):
    #Extract outputs from the body
    outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

    #Add custom layers
    sequence_output = self.dropout(outputs[0]) #outputs[0]=last hidden state

    logits = self.classifier(sequence_output[:,0,:].view(-1,768)) # calculate losses
    
    loss = None
    if labels is not None:
      loss_fct = Matryoshka_CE_Loss()
      loss = loss_fct(logits, labels)

    return loss
    
    # return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states,attentions=outputs.attentions)

In [42]:
model = CustomModel(num_labels=5)

In [43]:
model.to(device)

CustomModel(
  (model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affi

In [44]:
# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# model = AutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased", num_labels=5)
# model.to(device)

In [45]:
from tqdm.auto import tqdm

In [None]:
progress_bar = tqdm(range(num_training_steps))

model.train()
for epoch in range(num_epochs):
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs
        print(loss.detach().cpu())
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

  0%|                                                   | 0/375 [00:31<?, ?it/s]


tensor(13.0686)


  0%|                                         | 1/375 [00:10<1:07:50, 10.88s/it]

tensor(12.9162)


  1%|▏                                          | 2/375 [00:17<51:08,  8.23s/it]

tensor(13.0066)


  1%|▎                                          | 3/375 [00:23<46:00,  7.42s/it]

tensor(12.9278)


  1%|▍                                          | 4/375 [00:30<44:13,  7.15s/it]

tensor(13.0962)


  1%|▌                                          | 5/375 [00:37<43:56,  7.12s/it]

tensor(13.3258)


  2%|▋                                          | 6/375 [00:44<44:22,  7.22s/it]

tensor(12.9934)


  2%|▊                                          | 7/375 [00:52<44:28,  7.25s/it]

tensor(13.2176)


  2%|▉                                          | 8/375 [00:59<44:32,  7.28s/it]

tensor(12.8531)


  2%|█                                          | 9/375 [01:07<44:49,  7.35s/it]

tensor(12.9353)


  3%|█                                         | 10/375 [01:15<46:14,  7.60s/it]

tensor(13.0543)


  3%|█▏                                        | 11/375 [01:22<45:15,  7.46s/it]

tensor(13.4046)


  3%|█▎                                        | 12/375 [01:29<44:20,  7.33s/it]

tensor(12.9141)


  3%|█▍                                        | 13/375 [01:36<43:59,  7.29s/it]

tensor(12.8956)


  4%|█▌                                        | 14/375 [01:43<43:48,  7.28s/it]

tensor(13.0059)


  4%|█▋                                        | 15/375 [01:51<44:34,  7.43s/it]

tensor(13.2094)


  4%|█▊                                        | 16/375 [01:59<44:54,  7.50s/it]

tensor(13.1368)


  5%|█▉                                        | 17/375 [02:07<45:36,  7.64s/it]

tensor(13.0076)


  5%|██                                        | 18/375 [02:14<44:52,  7.54s/it]

tensor(13.3142)


  5%|██▏                                       | 19/375 [02:21<43:58,  7.41s/it]

tensor(13.0356)


  5%|██▏                                       | 20/375 [02:29<43:56,  7.43s/it]

tensor(12.8599)


  6%|██▎                                       | 21/375 [02:36<43:34,  7.39s/it]

tensor(12.8706)


  6%|██▍                                       | 22/375 [02:43<43:32,  7.40s/it]

tensor(13.1689)


  6%|██▌                                       | 23/375 [02:51<43:16,  7.38s/it]

tensor(13.1738)


  6%|██▋                                       | 24/375 [02:58<42:44,  7.31s/it]

tensor(13.2017)


  7%|██▊                                       | 25/375 [03:05<42:27,  7.28s/it]

tensor(13.1743)


  7%|██▉                                       | 26/375 [03:13<42:45,  7.35s/it]

tensor(13.0781)


  7%|███                                       | 27/375 [03:20<42:11,  7.27s/it]

tensor(12.8563)


  7%|███▏                                      | 28/375 [03:28<43:19,  7.49s/it]

tensor(13.2445)


  8%|███▏                                      | 29/375 [03:36<43:55,  7.62s/it]

tensor(13.1747)


  8%|███▎                                      | 30/375 [03:44<44:29,  7.74s/it]

tensor(13.1469)


  8%|███▍                                      | 31/375 [03:52<44:58,  7.85s/it]

tensor(13.0126)


  9%|███▌                                      | 32/375 [04:00<45:44,  8.00s/it]

tensor(12.8679)


  9%|███▋                                      | 33/375 [04:08<45:57,  8.06s/it]

tensor(13.2227)


  9%|███▊                                      | 34/375 [04:16<45:37,  8.03s/it]

tensor(13.4415)


  9%|███▉                                      | 35/375 [04:25<46:26,  8.20s/it]

tensor(13.2660)


 10%|████                                      | 36/375 [04:33<46:30,  8.23s/it]

tensor(13.0309)


 10%|████▏                                     | 37/375 [04:42<46:50,  8.31s/it]

tensor(12.9762)


 10%|████▎                                     | 38/375 [04:50<46:15,  8.24s/it]

tensor(13.0806)


 10%|████▎                                     | 39/375 [04:59<47:04,  8.41s/it]

tensor(13.0962)


 11%|████▍                                     | 40/375 [05:06<45:00,  8.06s/it]

tensor(13.0544)


 11%|████▌                                     | 41/375 [05:13<43:13,  7.76s/it]

tensor(13.3370)


 11%|████▋                                     | 42/375 [05:20<42:38,  7.68s/it]

tensor(13.2132)


 11%|████▊                                     | 43/375 [05:28<41:59,  7.59s/it]

tensor(13.1580)


 12%|████▉                                     | 44/375 [05:35<41:23,  7.50s/it]

tensor(13.2022)


 12%|█████                                     | 45/375 [05:43<41:31,  7.55s/it]

tensor(12.8106)


 12%|█████▏                                    | 46/375 [05:50<40:37,  7.41s/it]

tensor(13.1849)


 13%|█████▎                                    | 47/375 [05:57<39:50,  7.29s/it]

tensor(13.1091)


 13%|█████▍                                    | 48/375 [06:04<39:38,  7.27s/it]

tensor(12.8982)


 13%|█████▍                                    | 49/375 [06:11<39:36,  7.29s/it]

tensor(12.7810)


 13%|█████▌                                    | 50/375 [06:19<39:28,  7.29s/it]

tensor(13.0557)


 14%|█████▋                                    | 51/375 [06:26<38:58,  7.22s/it]

tensor(13.0392)


 14%|█████▊                                    | 52/375 [06:33<39:13,  7.29s/it]

tensor(13.2381)


 14%|█████▉                                    | 53/375 [06:40<39:04,  7.28s/it]

tensor(13.1486)


 14%|██████                                    | 54/375 [06:49<40:56,  7.65s/it]

tensor(13.1639)


 15%|██████▏                                   | 55/375 [06:57<42:20,  7.94s/it]

tensor(13.3751)


 15%|██████▎                                   | 56/375 [07:05<41:20,  7.78s/it]

tensor(13.1694)


 15%|██████▍                                   | 57/375 [07:12<40:47,  7.70s/it]

tensor(13.3378)


 15%|██████▍                                   | 58/375 [07:20<39:48,  7.53s/it]

tensor(13.0078)


 16%|██████▌                                   | 59/375 [07:27<39:16,  7.46s/it]

tensor(13.1890)


 16%|██████▋                                   | 60/375 [07:34<39:21,  7.50s/it]

tensor(13.1058)


 16%|██████▊                                   | 61/375 [07:42<39:28,  7.54s/it]

tensor(13.3771)


 17%|██████▉                                   | 62/375 [07:49<38:37,  7.40s/it]

tensor(12.8947)


 17%|███████                                   | 63/375 [07:57<39:19,  7.56s/it]

tensor(13.0744)


 17%|███████▏                                  | 64/375 [08:04<38:53,  7.50s/it]

tensor(13.0711)


 17%|███████▎                                  | 65/375 [08:11<38:02,  7.36s/it]

tensor(13.4896)


 18%|███████▍                                  | 66/375 [08:19<38:26,  7.46s/it]

tensor(12.8308)


 18%|███████▌                                  | 67/375 [08:27<38:41,  7.54s/it]

tensor(12.9575)


 18%|███████▌                                  | 68/375 [08:34<38:18,  7.49s/it]

tensor(12.9208)


 18%|███████▋                                  | 69/375 [08:42<38:11,  7.49s/it]

tensor(13.4065)


 19%|███████▊                                  | 70/375 [08:49<38:00,  7.48s/it]

tensor(13.1308)


 19%|███████▉                                  | 71/375 [08:57<37:52,  7.48s/it]

tensor(12.8992)


 19%|████████                                  | 72/375 [09:05<38:26,  7.61s/it]

tensor(13.1115)


 19%|████████▏                                 | 73/375 [09:12<37:38,  7.48s/it]

tensor(12.8452)


 20%|████████▎                                 | 74/375 [09:19<36:55,  7.36s/it]

tensor(13.1998)


 20%|████████▍                                 | 75/375 [09:26<36:39,  7.33s/it]

tensor(12.9301)


 20%|████████▌                                 | 76/375 [09:35<38:39,  7.76s/it]

tensor(13.4862)


 21%|████████▌                                 | 77/375 [09:43<38:49,  7.82s/it]

tensor(12.9615)


 21%|████████▋                                 | 78/375 [09:51<38:45,  7.83s/it]

tensor(12.6863)


 21%|████████▊                                 | 79/375 [09:58<38:09,  7.73s/it]

tensor(13.0989)


 21%|████████▉                                 | 80/375 [10:06<37:23,  7.61s/it]

tensor(13.2404)


 22%|█████████                                 | 81/375 [10:13<37:20,  7.62s/it]

tensor(13.2904)


 22%|█████████▏                                | 82/375 [10:21<37:36,  7.70s/it]

tensor(12.9515)


 22%|█████████▎                                | 83/375 [10:29<37:55,  7.79s/it]

tensor(12.9453)


 22%|█████████▍                                | 84/375 [10:38<39:08,  8.07s/it]

tensor(13.1111)


 23%|█████████▌                                | 85/375 [10:46<38:31,  7.97s/it]

tensor(13.1983)


 23%|█████████▋                                | 86/375 [10:53<38:13,  7.94s/it]

tensor(12.9074)


 23%|█████████▋                                | 87/375 [11:01<37:39,  7.85s/it]

tensor(13.0653)


 23%|█████████▊                                | 88/375 [11:09<37:21,  7.81s/it]

tensor(12.9111)


 24%|█████████▉                                | 89/375 [11:16<36:16,  7.61s/it]

tensor(13.3256)


 24%|██████████                                | 90/375 [11:23<36:00,  7.58s/it]

tensor(13.0590)


 24%|██████████▏                               | 91/375 [11:31<35:28,  7.49s/it]

tensor(13.5669)


 25%|██████████▎                               | 92/375 [11:38<35:38,  7.56s/it]

tensor(13.1383)


 25%|██████████▍                               | 93/375 [11:46<35:42,  7.60s/it]

tensor(13.1040)


 25%|██████████▌                               | 94/375 [11:53<35:18,  7.54s/it]

tensor(13.2192)


 25%|██████████▋                               | 95/375 [12:01<34:48,  7.46s/it]