In [None]:
from transformers import AutoTokenizer, AutoModel
import torch

# Load the multilingual DistilBERT model and its matching tokenizer
model_name = "distilbert-base-multilingual-cased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Pick the compute device once and reuse it for both the model and the inputs.
# Apple's MPS backend is preferred (as before); CUDA and then the CPU are used
# as fallbacks so the cell no longer fails on machines without MPS.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)

# Example text in multiple languages
text = "This is an English sentence. Это предложение на русском. 这是一句中文。"

# Tokenize the text and move the resulting tensors to the same device as the model
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)

# Pass through the model.
# Gradient tracking is intentionally left enabled: the torchviz cell further
# down builds its diagram from the autograd graph attached to `outputs`.
outputs = model(**inputs)

# Print the output
print(outputs.last_hidden_state)  # The embeddings for each token


In [None]:
import torchinfo

# Summarise the model's layers and parameter counts.
# No input_size or input_data is passed here, so no batch size or sequence
# length is fixed for this summary.
model_summary = torchinfo.summary(model)
print(model_summary)

In [None]:
import torchviz
from IPython.display import Image
import os
# NOTE(review): presumably set so the tokenizers library stays quiet when the
# render step below launches an external process — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Base name (without extension) for the rendered graph
file_path = "model_graph"

# Build the graph from the last hidden state, labelling nodes with the
# model's named parameters, then render it as a PNG
named_params = dict(model.named_parameters())
dot = torchviz.make_dot(outputs.last_hidden_state, params=named_params)
dot.render(file_path, format="png")

# Display the rendered image as the cell's output
Image(f"{file_path}.png")


In [15]:
# Show the module hierarchy (the model's textual representation)
print(str(model))

DistilBertModel(
  (embeddings): Embeddings(
    (word_embeddings): Embedding(119547, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (layer): ModuleList(
      (0-5): 6 x TransformerBlock(
        (attention): DistilBertSdpaAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): 