In [1]:
import os
os.environ['http_proxy'] = "http://192.41.170.23:3128" 
os.environ['https_proxy'] = "http://192.41.170.23:3128" 
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [2]:
import torch
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch.cuda.get_device_name(device)

'NVIDIA RTX A6000'

In [3]:
from transformers import AutoModelForMaskedLM
from transformers import AutoTokenizer, AutoModelForSequenceClassification
text = "This is a great [MASK]."
#text = "The capital of France is [MASK]."
model_checkpoint = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint).to(device)

In [4]:
### ALL token Fill
inputs = tokenizer(text, return_tensors="pt").to(device)
token_logits = model(**inputs).logits
# Find the location of [MASK] and extract its logits
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]
# Pick the [MASK] candidates with the highest logits

top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token in top_5_tokens:
    print(f"'>>> {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}'",token)

'>>> This is a great deal.' 3066
'>>> This is a great success.' 3112
'>>> This is a great adventure.' 6172
'>>> This is a great idea.' 2801
'>>> This is a great feat.' 8658


# Template 

In [5]:
# template must have 3 things : input, symbol, mask token 
s1,s2,s3,s4 = ":" , ">" , "-" , "," 
mask = " [MASK]"
inputs =  "xxx" 
t1 = inputs + s1 + mask 
t2 = inputs + s2 + mask 
t3 = inputs + s3 + mask 
t4 = inputs + s4 + mask 

In [6]:
# find verbalizer token 
ids_labels = tokenizer.convert_tokens_to_ids(['good','bad'])
print(tokenizer.convert_ids_to_tokens(ids_labels) , ids_labels)

['good', 'bad'] [2204, 2919]


In [7]:
from transformers import AutoModelForMaskedLM
from transformers import AutoTokenizer, AutoModelForSequenceClassification
text = "This is a great [MASK]."
#text = "The capital of France is [MASK]."
model_checkpoint = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint).to(device)

In [8]:
prompts = tokenizer(t1,padding=True ,return_tensors="pt").to(device)
token_logits = model(**prompts).logits

# Find the location of [MASK] and extract its logits
mask_token_index = torch.where(prompts["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, ids_labels]

# sort 
sort_scores = torch.argsort(mask_token_logits,descending=True)
scores = torch.argmax(mask_token_logits)
print(scores)

for idx in sort_scores:
    token = ids_labels[idx]
    print(f"'{t1.replace(tokenizer.mask_token, tokenizer.decode([token]))}'",idx)
          #,token_logits[0, mask_token_index, ids_labels[idx]])

tensor(0, device='cuda:0')
'xxx: good' tensor(0, device='cuda:0')
'xxx: bad' tensor(1, device='cuda:0')


# sst2 

In [9]:
from datasets import load_dataset
sst = load_dataset("sst2")
sst 

Found cached dataset sst2 (/home/arnajakt/.cache/huggingface/datasets/sst2/default/2.0.0/9896208a8d85db057ac50c72282bcb8fe755accc671a57dd8059d4e130961ed5)


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['idx', 'sentence', 'label'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['idx', 'sentence', 'label'],
        num_rows: 872
    })
    test: Dataset({
        features: ['idx', 'sentence', 'label'],
        num_rows: 1821
    })
})

In [27]:
trainsst = sst['train']
valsst = sst['validation']
testsst = sst['test']
trainsst

Dataset({
    features: ['idx', 'sentence', 'label'],
    num_rows: 67349
})

In [28]:
prompts = list(map(lambda sentence : sentence + ": " + "[MASK]" , trainsst["sentence"]))
trainsst = trainsst.add_column('prompts',prompts)
trainsst = trainsst.remove_columns(['sentence'])
trainsst

Dataset({
    features: ['idx', 'label', 'prompts'],
    num_rows: 67349
})

In [29]:
from transformers import AutoTokenizer , AutoModelForMaskedLM

model_checkpoint = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

def tokenize_function(examples):
    return tokenizer(examples["prompts"], padding="max_length", truncation=True)

tokenized_datasets = trainsst.map(tokenize_function, batched=True)

Loading cached processed dataset at /home/arnajakt/.cache/huggingface/datasets/sst2/default/2.0.0/9896208a8d85db057ac50c72282bcb8fe755accc671a57dd8059d4e130961ed5/cache-22c6c97846910931.arrow


In [30]:
tokenized_datasets = tokenized_datasets.remove_columns(["prompts","idx"])
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch")
tokenized_datasets

Dataset({
    features: ['labels', 'input_ids', 'attention_mask'],
    num_rows: 67349
})

In [31]:
small_train_dataset = tokenized_datasets.shuffle(seed=42).select(range(1000))

Loading cached shuffled indices for dataset at /home/arnajakt/.cache/huggingface/datasets/sst2/default/2.0.0/9896208a8d85db057ac50c72282bcb8fe755accc671a57dd8059d4e130961ed5/cache-c131f29b5360a69d.arrow


In [32]:
def token_data(dataset):
    prompts = list(map(lambda sentence : sentence + ": " + "[MASK]" , dataset["sentence"]))
    dataset = dataset.add_column('prompts',prompts)
    dataset = dataset.remove_columns(['sentence'])
    
    tokenized_datasets = trainsst.map(tokenize_function, batched=True)
    tokenized_datasets = tokenized_datasets.remove_columns(["prompts","idx"])
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
    tokenized_datasets.set_format("torch")
    return tokenized_datasets

In [33]:
eval_dataset = token_data(valsst)
small_eval_dataset = eval_dataset.shuffle(seed=42).select(range(1000))

Loading cached processed dataset at /home/arnajakt/.cache/huggingface/datasets/sst2/default/2.0.0/9896208a8d85db057ac50c72282bcb8fe755accc671a57dd8059d4e130961ed5/cache-22c6c97846910931.arrow
Loading cached shuffled indices for dataset at /home/arnajakt/.cache/huggingface/datasets/sst2/default/2.0.0/9896208a8d85db057ac50c72282bcb8fe755accc671a57dd8059d4e130961ed5/cache-c131f29b5360a69d.arrow


In [34]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=8)
eval_dataloader = DataLoader(small_eval_dataset, batch_size=8)

In [35]:
# Just check 
for batch in train_dataloader:
    break
{k: v.shape for k, v in batch.items()}

{'labels': torch.Size([8]),
 'input_ids': torch.Size([8, 512]),
 'attention_mask': torch.Size([8, 512])}

In [36]:
from transformers import AutoModelForMaskedLM
model_checkpoint = "distilbert-base-uncased"
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint).to(device)

In [37]:
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=5e-5)

In [38]:
from accelerate import Accelerator

accelerator = Accelerator()
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

In [42]:
num_training_steps

375

In [39]:
from transformers import get_scheduler

num_train_epochs = 3
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch

lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

In [47]:
from tqdm.auto import tqdm
import torch
import math

progress_bar = tqdm(range(num_training_steps))

for epoch in range(num_train_epochs):
    # Training
    model.train()
    for batch in train_dataloader:
        global x 
        x = batch
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

    # Evaluation
    model.eval()
    losses = []
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            outputs = model(**batch)

        loss = outputs.loss
        losses.append(accelerator.gather(loss.repeat(batch_size)))

    losses = torch.cat(losses)
    losses = losses[: len(eval_dataset)]
    try:
        perplexity = math.exp(torch.mean(losses))
    except OverflowError:
        perplexity = float("inf")

    print(f">>> Epoch {epoch}: Perplexity: {perplexity}")

    # Save and upload
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
    if accelerator.is_main_process:
        tokenizer.save_pretrained(output_dir)
        repo.push_to_hub(
            commit_message=f"Training in progress epoch {epoch}", blocking=False
        )

  0%|          | 0/375 [00:00<?, ?it/s]

ValueError: Expected input batch_size (4096) to match target batch_size (8).

In [48]:
x

{'labels': tensor([1, 1, 1, 0, 0, 0, 1, 1], device='cuda:0'),
 'input_ids': tensor([[  101, 10973,  1011,  ...,     0,     0,     0],
         [  101,  2307,  6018,  ...,     0,     0,     0],
         [  101,  2092,  1011,  ...,     0,     0,     0],
         ...,
         [  101,  2003,  1996,  ...,     0,     0,     0],
         [  101,  2293,  1996,  ...,     0,     0,     0],
         [  101,  6429,  1024,  ...,     0,     0,     0]], device='cuda:0'),
 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0')}