In [1]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding, AdamW, get_scheduler
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset, load_metric
from tqdm.auto import tqdm

In [2]:
# Single checkpoint name shared by the model and its tokenizer so the two
# always come from the same pretrained release.
checkpoint = "distilbert-base-uncased"

# Sequence-classification model configured for 10 output labels.
model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint,
    num_labels=10,
)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [3]:
# Swap in a freshly initialised linear head: the model's hidden size in,
# 10 logits out.
new_classifier = torch.nn.Linear(
    in_features=model.config.hidden_size,
    out_features=10,
)
model.classifier = new_classifier

In [4]:
# Read the whole CSV as a single split.
dataset = load_dataset("csv", data_files="dataset_sheet.csv", split='train')
# Convert the "label" column to a ClassLabel feature; stratify_by_column
# below is applied to this encoded column.
dataset = dataset.class_encode_column("label")
# 80/20 split that keeps the label proportions in both parts. The seed is
# fixed so that Restart & Run All yields the same train/test membership
# (and therefore comparable scores) on every run.
full_dataset = dataset.train_test_split(
    test_size=0.2,
    stratify_by_column="label",
    seed=42,
)

In [5]:
def tokenize_function(example):
    """Tokenize the "prompt" field of a batch of examples.

    Sequences are truncated but deliberately NOT padded here. With
    ``batched=True`` the function receives many rows at once, so padding at
    this stage would stretch every row to the longest prompt in that large
    chunk. Padding is left to the ``DataCollatorWithPadding`` used by the
    DataLoaders, which pads each training batch only to its own longest
    sequence. ``return_tensors`` is omitted as well: ragged sequences cannot
    be stacked into a tensor, and ``map`` stores plain lists regardless.

    Args:
        example: mapping with a "prompt" key holding the batch's texts.

    Returns:
        The tokenizer output for the batch.
    """
    return tokenizer(example["prompt"], truncation=True)

tokenized_datasets = full_dataset.map(tokenize_function, batched=True)

In [6]:
# Collator that pads the examples of each batch when the batch is assembled.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Drop the raw text column and rename the target column to "labels".
# The membership checks make the cell safe to re-run: on a second execution
# "prompt" is already gone and "label" is already renamed, and the unguarded
# calls would raise.
train_columns = tokenized_datasets["train"].column_names
if "prompt" in train_columns:
    tokenized_datasets = tokenized_datasets.remove_columns(["prompt"])
if "label" in train_columns:
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
# Have the datasets hand back torch tensors.
tokenized_datasets.set_format("torch")

# Training batches are shuffled; evaluation batches keep dataset order.
train_dataloader = DataLoader(
    tokenized_datasets["train"], shuffle=True, batch_size=8, collate_fn=data_collator
)

eval_dataloader = DataLoader(
    tokenized_datasets["test"], batch_size=8, collate_fn=data_collator
)

In [7]:
# torch.optim.AdamW stands in for the deprecated transformers.AdamW.
# weight_decay and eps are pinned to the values transformers.AdamW used by
# default (0.0 and 1e-6) because torch's own defaults differ (0.01 and 1e-8).
optimizer = torch.optim.AdamW(
    model.parameters(), lr=0.00005, weight_decay=0.0, eps=1e-6
)

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

num_epochs = 5
# The scheduler is stepped once per batch, so it must span every batch of
# every epoch.
num_training_steps = num_epochs * len(train_dataloader)
# "linear" schedule with no warmup steps.
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

progress_bar = tqdm(range(num_training_steps))

model.train()
for epoch in range(num_epochs):
    progress_bar.clear()
    for batch in train_dataloader:
        # Move every tensor of the batch onto the model's device.
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        # Apply the update, advance the schedule, then clear the gradients
        # so they do not accumulate into the next batch.
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

In [8]:
# Running total of the per-batch scores and the mean derived from it;
# both start at zero and are filled in by the evaluation cells.
accuracy_sum, accuracy_avg = 0, 0

In [9]:
metric = load_metric("f1")

# Start from zero so that re-running this cell does not keep adding to the
# total left behind by a previous run.
accuracy_sum = 0

# The training loop leaves the model in train mode; switch to eval mode so
# dropout is disabled and the scores are deterministic.
model.eval()
for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    # No gradients are needed for scoring.
    with torch.no_grad():
        outputs = model(**batch)

    logits = outputs.logits
    # Predicted class = index of the largest logit.
    predictions = torch.argmax(logits, dim=-1)
    # Weighted F1 of this batch only.
    # NOTE(review): the mean of per-batch F1 scores is not the same as F1
    # computed over the whole test split — confirm this is the intended metric.
    current_accuracy = metric.compute(predictions=predictions, references=batch["labels"], average="weighted")
    print(current_accuracy['f1'])
    accuracy_sum += current_accuracy['f1']

In [10]:
print(accuracy_sum)
# Divide by the actual number of evaluation batches rather than a
# hard-coded 25, which is only right for one particular dataset size.
accuracy_avg = accuracy_sum / len(eval_dataloader)
print(accuracy_avg)

In [11]:
# Lookup from predicted class index (0-9) to the intent name.
output_to_text = dict(enumerate([
    "forward",
    "list",
    "read",
    "reply",
    "send",
    "star",
    "trash",
    "trash_list",
    "unknown",
    "untrash",
]))

In [12]:
example_inputs = ["write a new mail", "delete this mail", "mark this mail as important", "read this mail", "show inbox", "list trash mails", "please untrash this mail", "The weather is nice today.", "i'd like to reply to this mail", "Please forward this mail"]

# The training loop leaves the model in train mode; disable dropout before
# predicting so the same input always gives the same answer.
model.eval()
for example_input in example_inputs:
    # Send the inputs to whichever device the model lives on instead of
    # hard-coding "cuda", which fails on CPU-only machines.
    encoded_input = tokenizer(example_input, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model(**encoded_input)
        logits = output.logits
        # Index of the highest-scoring class for this single example.
        predicted_intent = logits.argmax(-1).item()
        print(f"{example_input} -> {predicted_intent}")

In [13]:
# Persist the trained model and its tokenizer to local directories so they
# can be reloaded later with from_pretrained().
model.save_pretrained("trained_model_new")
# Last expression of the cell: its return value is what the cell displays.
tokenizer.save_pretrained("trained_tokenizer_new")

In [14]:
# Reload the saved artifacts from disk under separate names so they can be
# checked independently of the in-memory model and tokenizer.
test_tokenizer = AutoTokenizer.from_pretrained("trained_tokenizer_new")
test_model = AutoModelForSequenceClassification.from_pretrained(
    "trained_model_new"
)

In [15]:
example_input = "Trash this mail now."

# Exercise the artifacts reloaded from disk (test_model / test_tokenizer)
# rather than the in-memory training objects; otherwise the saved files are
# never actually verified.
test_model.eval()
# Place the inputs on the reloaded model's device instead of hard-coding
# "cuda", which fails on CPU-only machines and would not match a model that
# was loaded onto the CPU.
encoded_input = test_tokenizer(example_input, return_tensors="pt").to(test_model.device)
with torch.no_grad():
    output = test_model(**encoded_input)
    logits = output.logits
    # Index of the highest-scoring class, translated to its intent name.
    predicted_intent = logits.argmax(-1).item()
    print(f"{example_input} -> {output_to_text[predicted_intent]}")