<a href="https://colab.research.google.com/github/anna-marshalova/model-compression-course/blob/main/mbert_compression_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive

# Attach Google Drive to the Colab runtime; the model checkpoint loaded later
# lives under this mount point. force_remount=True forces a fresh mount.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT, force_remount=True)

Mounted at /content/drive


In [None]:
!pip install transformers datasets

Collecting transformers
  Downloading transformers-4.33.1-py3-none-any.whl (7.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m72.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets
  Downloading datasets-2.14.5-py3-none-any.whl (519 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.6/519.6 kB[0m [31m40.8 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.15.1 (from transformers)
  Downloading huggingface_hub-0.17.1-py3-none-any.whl (294 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m294.8/294.8 kB[0m [31m32.3 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m117.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors>=0.3.1 (from transformers)
  Downloa

In [None]:
from transformers import AutoTokenizer

# Multilingual cased BERT checkpoint (Hugging Face Hub identifier).
MODEL_NAME = 'bert-base-multilingual-cased'

# Tokenizer matching MODEL_NAME.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=MODEL_NAME)

Downloading (…)okenizer_config.json:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/996k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

In [None]:
from torch.nn import Module, Linear
from transformers import AutoModel


class Model(Module):
    """Classifier made of a pretrained transformer encoder plus a linear head.

    The encoder's ``pooler_output`` is passed through a single ``Linear`` layer
    producing ``num_labels`` logits.

    This class is presumably defined here so that ``torch.load`` of a pickled
    ``Model`` object can resolve it; keep the class name and the ``emb`` / ``out``
    attribute names unchanged or previously saved checkpoints will not unpickle
    into a working model.

    Args:
        pretrained_model_name: Name or path given to ``AutoModel.from_pretrained``.
        num_labels: Number of output classes (output width of the head).
        freeze: If True (default), all encoder parameters get
            ``requires_grad=False`` so only the linear head is trainable.
        **kwargs: Accepted and ignored (kept for call-site compatibility).
    """

    def __init__(self, pretrained_model_name, num_labels, freeze=True, **kwargs):
        super().__init__()
        self.emb = AutoModel.from_pretrained(
            pretrained_model_name,
            output_attentions=False,
            output_hidden_states=False,
        )
        if freeze:
            for param in self.emb.parameters():
                param.requires_grad = False
        self.out = Linear(self._hidden_size(self.emb), num_labels)

    @staticmethod
    def _hidden_size(encoder):
        """Return the width of ``encoder``'s pooled output.

        Prefers ``encoder.config.hidden_size``. If the encoder has no such
        config field, falls back to the length of its last registered parameter.
        That heuristic only holds when the last parameter is sized like the
        hidden state (e.g. a final pooler bias) — verify for other architectures.
        """
        config = getattr(encoder, 'config', None)
        hidden_size = getattr(config, 'hidden_size', None)
        if hidden_size is not None:
            return hidden_size
        return list(encoder.parameters())[-1].shape[0]

    def forward(self, input_ids, attention_mask, **kwargs):
        """Compute class logits for a tokenized batch.

        Only ``input_ids`` and ``attention_mask`` are forwarded to the encoder;
        any other keyword arguments (e.g. extra tokenizer outputs) are ignored.
        Returns the head applied to the encoder's ``pooler_output`` — presumably
        a (batch, num_labels) tensor; confirm against the encoder in use.
        """
        pooled = self.emb(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        return self.out(pooled)

In [None]:
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
CHECKPOINTS_PATH = '/content/drive/MyDrive/model_compression/model.pt'

# The checkpoint appears to be a whole pickled Model object (it is used directly
# as a model below), not a state_dict, so full unpickling is required:
# weights_only=False is passed explicitly because newer torch releases default
# to weights_only=True and refuse such files.
# SECURITY: unpickling can execute arbitrary code — only load trusted checkpoints.
model = torch.load(CHECKPOINTS_PATH, map_location=device, weights_only=False)

# Inference-only notebook: disable dropout regardless of the mode the object was
# saved in. The trailing ';' suppresses the module repr in the notebook output.
model.eval();

In [None]:
from importlib import reload
import inference

# Re-execute the local inference.py so edits to it take effect without
# restarting the kernel; the reloaded module object is the cell's output.
reload(inference)

<module 'inference' from '/content/inference.py'>

In [None]:
from inference import get_test_data

In [None]:
# get_test_data() returns a (dataset, loader) pair: the dataset is scored per
# language below and the loader drives the inference-time benchmark.
test_split = get_test_data()
test_dataset, test_loader = test_split

Downloading builder script:   0%|          | 0.00/6.23k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/26.2k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/22.4M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/774k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/895k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/270399 [00:00<?, ? examples/s]

----------------------------------------------------------------------------------------------------
                                                text     label  source  \
0    yang memerlukan pemerhatian dan tindakan serius  positive  malaya   
1  sentiasa memikirkan dan merancang inisiatif ba...  positive  malaya   
2  Kita akan tengok daripada pelbagai aspek supay...  positive  malaya   
3  justeru asean perlu mengambil tindakan sebagai...  positive  malaya   
4  @_Niiar_ Jangan punah dulu, aku belum ke labua...   neutral  malaya   

  language  
0    malay  
1    malay  
2    malay  
3    malay  
4    malay  
----------------------------------------------------------------------------------------------------


Generating validation split:   0%|          | 0/10857 [00:00<?, ? examples/s]

----------------------------------------------------------------------------------------------------
                                                text     label  source  \
0  kerja sama apec perlu wujudkan pertumbuhan eko...  positive  malaya   
1    Ada yang disembunyikan, dan itu yang dipercayai  negative  malaya   
2  Apatah lagi penduduk negara ini terdiri daripa...  positive  malaya   
3  tetapi sejak krisis kewangan asia hingga sekar...  negative  malaya   
4  KOTA KINABALU: Warga emas tertua di Sabah mene...  positive  malaya   

  language  
0    malay  
1    malay  
2    malay  
3    malay  
4    malay  
----------------------------------------------------------------------------------------------------


Generating test split:   0%|          | 0/14465 [00:00<?, ? examples/s]

----------------------------------------------------------------------------------------------------
                                                text     label  source  \
0  Sepatutnya berbuat begitu  demi untuk menawark...  positive  malaya   
1  Alhamdulillah, sama2 bantu kerajaan memerangi ...  positive  malaya   
2  Biasanya bantuan disalurkan kepada sekolah ber...  positive  malaya   
3  Kerajaan wajar mengkaji semula had tunggakan c...  positive  malaya   
4  me; those everytime nak beli baju or seluar ta...  positive  malaya   

  language  
0    malay  
1    malay  
2    malay  
3    malay  
4    malay  
----------------------------------------------------------------------------------------------------


In [None]:
from inference import eval_all_langs, scores_to_df, measure_size_mb, measure_inference_time

In [None]:
# Languages present in the test split returned by get_test_data() above.
# NOTE(review): assumes test_dataset exposes a 'language' column (the split
# previews printed by get_test_data show one) — confirm in inference.py.
# Python set iteration order for strings can change between sessions, so the
# row order of the resulting score table may vary from run to run.
test_langs = set(test_dataset['language'])
scores = eval_all_langs(test_dataset, model, test_langs)

In [None]:
# Per-language F1 on the test split, displayed as a table.
lang_scores_df = scores_to_df(scores)
lang_scores_df

Unnamed: 0,language,f1
0,chinese,0.556626
1,arabic,0.344552
2,portuguese,0.36925
3,japanese,0.499123
4,malay,0.565781
5,english,0.343242
6,hindi,0.359026
7,german,0.31625
8,indonesian,0.327033
9,spanish,0.272859


In [None]:
# Report the loaded model's size in MB (the helper prints "Model size: ...MB").
# NOTE(review): whether this counts parameters only or also buffers is defined
# in inference.py — confirm before comparing against compressed variants.
measure_size_mb(model)

Model size: 678.474MB


In [None]:
# Benchmark inference over test_loader (progress bar iterates its batches) and
# print an average time.
# NOTE(review): the printed value carries no unit and it is not shown whether it
# is per batch or per example — confirm in inference.py before reporting it.
measure_inference_time(model, test_loader)

  0%|          | 0/46 [00:00<?, ?it/s]

Avg inference time: 18.049402236938477
