In [1]:
import torch
from datasets import load_from_disk
from transformers import (
    BartTokenizer,
    BartForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Загрузка датасета
dataset = load_from_disk('dataset')
train_dataset = dataset['train']
test_dataset = dataset['test']

# Инициализация токенизатора и модели
model_name = 'facebook/bart-base'
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name).to('cuda')

In [3]:
# Функция для анализа длин в токенах
def analyze_token_lengths(dataset, column_name):
    token_lengths = [len(tokenizer.encode(text)) for text in dataset[column_name]]
    max_length = max(token_lengths)
    min_length = min(token_lengths)
    avg_length = sum(token_lengths) / len(token_lengths)
    p95_length = sorted(token_lengths)[int(len(token_lengths) * 0.95)]
    return {
        "max": max_length,
        "min": min_length,
        "avg": avg_length,
        "p95": p95_length
    }

# Анализ длин входных данных в токенах
input_token_stats = analyze_token_lengths(train_dataset, "test_scenario")
print("Токенизированные длины для test_scenario:")
print(input_token_stats)

# Анализ длин выходных данных в токенах
target_token_stats = analyze_token_lengths(train_dataset, "test_steps")
print("\nТокенизированные длины для test_steps:")
print(target_token_stats)

Токенизированные длины для test_scenario:
{'max': 31, 'min': 5, 'avg': 15.987961476725522, 'p95': 22}

Токенизированные длины для test_steps:
{'max': 302, 'min': 8, 'avg': 43.5920278223649, 'p95': 68}


In [3]:

# Предобработка данных
def preprocess_function(examples):
    inputs = examples['test_scenario']
    targets = examples['test_steps']
    model_inputs = tokenizer(
        inputs,
        max_length=32,
        truncation=True,
        padding='max_length'
    )
    labels = tokenizer(
        text_target=targets,
        max_length=256,
        truncation=True,
        padding='max_length'
    )
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

tokenized_train = train_dataset.map(preprocess_function, batched=True)
tokenized_test = test_dataset.map(preprocess_function, batched=True)

In [4]:
# Настройка data collator
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# Конфигурация обучения
training_args = Seq2SeqTrainingArguments(
    output_dir='./models',
    eval_strategy='epoch',
    save_strategy='epoch',
    logging_steps=10,
    save_total_limit=2,
    num_train_epochs=10,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=2e-5,
    weight_decay=0.01,
    predict_with_generate=True,
    bf16=True,  # Включаем bfloat16
    fp16=False,
    report_to="none",
    load_best_model_at_end=True
)

# Инициализация тренера
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    data_collator=data_collator,
    tokenizer=tokenizer
)

  trainer = Seq2SeqTrainer(


In [5]:
# Обучение модели
trainer.train()

# Сохранение модели
trainer.save_model('./models/bart-test-case-generator')

Epoch,Training Loss,Validation Loss
1,0.2544,0.219952
2,0.2265,0.176177
3,0.1939,0.162411
4,0.1486,0.154066
5,0.1451,0.146769
6,0.1519,0.141722
7,0.1245,0.138588
8,0.117,0.136306
9,0.1228,0.135062
10,0.1302,0.134203


There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.decoder.embed_tokens.weight', 'lm_head.weight'].


In [6]:
# Проверка результатов
def generate_steps(text):
    inputs = tokenizer(text, return_tensors='pt', max_length=32, truncation=True).to('cuda')
    output = model.generate(**inputs, max_length=256)
    return tokenizer.decode(output[0], skip_special_tokens=True)

# Тестирование на примерах
sample = test_dataset.select(range(5))
for example in sample:
    print(f"Scenario: {example['test_scenario']}")
    print(f"Generated steps: {generate_steps(example['test_scenario'])}")
    print(f"Actual steps: {example['test_steps']}\n{'='*50}")

Scenario: Verify course progress sync across devices when there are no interruptions.
Generated steps: 1. Log in to the online learning platform on Device A.
2. Navigate to a course on Device B.
3. Check if the course progress is synchronized across devices.
Actual steps: 1. Login to the online learning platform on Device A and start a course.
2. Verify that the progress is synced to Device B without any interruptions.
Scenario: Validate the accuracy of QoS metrics monitoring tool by monitoring network latency during peak hours.
Generated steps: 1. Simulate high network latency during peak hours.
2. Monitor network response time and latency.
Actual steps: 1. Simulate peak hours traffic on the network.
2. Monitor network latency for video streaming.
Scenario: Verify that a user can send a message to a project collaborator.
Generated steps: 1. Log in to the platform as a project collaborator.
2. Navigate to the messaging feature.
3. Select a specific project collaborator to send a messag